[Mlir-commits] [mlir] [mlir][ArmSME] Switch to an attribute-based tile allocation scheme (PR #73253)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Nov 23 08:43:09 PST 2023
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/73253
This reworks the ArmSME dialect to use attributes for tile allocation. This has a number of advantages and corrects some issues with the previous approach:
* Tile allocation can now be done ASAP (i.e. immediately after `-convert-vector-to-arm-sme`)
* SSA form for control flow is now supported (e.g. `scf.for` loops that yield tiles)
* ArmSME ops can be converted to intrinsics very late (i.e. after lowering to control flow)
* Tests are simplified by removing constants and casts
* Avoids correctness issues with representing LLVM `immargs` as MLIR values
- The tile ID on the SME intrinsics is an `immarg` (so it is required to be a compile-time constant). `immargs` should be mapped to MLIR attributes (this is already the case for intrinsics in the LLVM dialect)
- Using MLIR values for `immargs` can lead to invalid LLVM IR being generated (and passes such as -cse making incorrect optimizations); see the C++ sketch just after this list
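For illustration, this snippet is adapted from the updated `arm_sme.zero` lowering in this patch: the intrinsic's `immarg` is now computed at rewrite time and passed as an attribute (a compile-time constant) rather than as an SSA value:

```c++
// The zero intrinsic's tile mask is a compile-time constant, computed from
// the allocated tile ID and passed as an i32 attribute (an immarg).
int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
rewriter.create<arm_sme::aarch64_sme_zero>(
    loc, rewriter.getI32IntegerAttr(zeroMask));
```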
As part of this patch we bid farewell to the following operations:
```mlir
arm_sme.get_tile_id : i32
arm_sme.cast_tile_to_vector : i32 to vector<[4]x[4]xi32>
arm_sme.cast_vector_to_tile : vector<[4]x[4]xi32> to i32
```
These are now replaced with:
```mlir
// Allocates a new tile with (indeterminate) state:
arm_sme.get_tile : vector<[4]x[4]xi32>
// A placeholder operation for lowering ArmSME ops to intrinsics:
arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
```
The new tile allocation works by operations implementing the `ArmSMETileOpInterface`. This interface says that an operation needs to be assigned a tile ID, and may conditionally allocate a new SME tile.
Operations allocate a new tile by implementing...
```c++
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType()
```
...and returning the type of tile the op allocates (ZAB, ZAH, etc.).
Operations that don't allocate a tile return `std::nullopt` (which is the default behaviour).
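For example, `arm_sme.zero` (which produces a fresh tile when lowered) implements this roughly as follows:

```c++
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
  // Map the result vector type (e.g. vector<[4]x[4]xi32>) to its tile type (ZAS).
  return arm_sme::getSMETileType(getVectorType());
}
```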
Currently the following ops are defined as allocating:
```mlir
arm_sme.get_tile
arm_sme.zero
arm_sme.tile_load
arm_sme.outerproduct // (if no accumulator is specified)
```
Allocating operations become the roots for the tile allocation pass, which currently just (naively) assigns all transitive uses of a root operation the same tile ID. However, this is enough to handle current use cases.
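For example (a hand-written sketch, not taken from this patch's tests), after tile allocation independent roots end up with distinct `tile_id` attributes, and every transitive user of a root carries the root's ID:

```mlir
// Two independent roots are assigned different tiles:
%a = arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xi32>
%b = arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xi32>
// Ops that (transitively) use %a are also assigned {tile_id = 0 : i32}.
```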
Once tile IDs have been allocated, subsequent rewrites can forward them to any newly created operations.
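The interface provides helpers for this. For instance, the `arm_sme.tile_load` lowering in ArmSMEToSCF (see the diff below) creates its tile-slice loads with:

```c++
// Creates an arm_sme.load_tile_slice that inherits tileLoadOp's tile_id.
tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
    rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
    memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
```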
>From a9595e75f75a7ad132318c4edfcb31892bba6e29 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 23 Nov 2023 14:48:32 +0000
Subject: [PATCH] [mlir][ArmSME] Switch to an attribute-based tile allocation
scheme
This reworks the ArmSME dialect to use attributes for tile allocation.
This has a number of advantages and corrects some issues with the
previous approach:
* Tile allocation can now be done ASAP (i.e. immediately after
`-convert-vector-to-arm-sme`)
* SSA form for control flow is now supported (e.g. `scf.for` loops that
yield tiles)
* ArmSME ops can be converted to intrinsics very late (i.e. after
lowering to control flow)
* Tests are simplified by removing constants and casts
* Avoids correctness issues with representing LLVM `immargs` as MLIR
values
- The tile ID on the SME intrinsics is an `immarg` (so it is required
to be a compile-time constant). `immargs` should be mapped to
MLIR attributes (this is already the case for intrinsics in the
LLVM dialect)
- Using MLIR values for `immargs` can lead to invalid LLVM IR being
generated (and passes such as -cse making incorrect optimizations)
As part of this patch we bid farewell to the following operations:
```mlir
arm_sme.get_tile_id : i32
arm_sme.cast_tile_to_vector : i32 to vector<[4]x[4]xi32>
arm_sme.cast_vector_to_tile : vector<[4]x[4]xi32> to i32
```
These are now replaced with:
```mlir
// Allocates a new tile with (indeterminate) state:
arm_sme.get_tile : vector<[4]x[4]xi32>
// A placeholder operation for lowering ArmSME ops to intrinsics:
arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
```
The new tile allocation works by operations implementing the
`ArmSMETileOpInterface`. This interface says that an operation needs
to be assigned a tile ID, and may conditionally allocate a new SME tile.
Operations allocate a new tile by implementing...
```c++
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType()
```
...and returning what type of tile the op allocates (ZAB, ZAH, etc).
Operations that don't allocate a tile return `std::nullopt` (which
is the default behaviour).
Currently the following ops are defined as allocating:
```mlir
arm_sme.get_tile
arm_sme.zero
arm_sme.tile_load
arm_sme.outerproduct // (if no accumulator is specified)
```
Allocating operations become the roots for the tile allocation pass,
which currently just (naively) assigns all transitive uses of a root
operation the same tile ID. However, this is enough to handle current
use cases.
Once tile IDs have been allocated, subsequent rewrites can forward them
to any newly created operations.
---
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h | 4 +-
.../mlir/Dialect/ArmSME/IR/ArmSMEEnums.h | 16 +
.../Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td | 52 ++-
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 283 ++++++------
.../mlir/Dialect/ArmSME/IR/CMakeLists.txt | 6 +
.../include/mlir/Dialect/ArmSME/Utils/Utils.h | 16 +-
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 197 ++++-----
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 47 +-
.../VectorToArmSME/VectorToArmSME.cpp | 23 +-
mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp | 22 +-
mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt | 1 +
.../ArmSME/Transforms/TileAllocation.cpp | 105 +++--
mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt | 2 -
mlir/lib/Dialect/ArmSME/Utils/Utils.cpp | 44 +-
.../ArmSMEToSCF/arm-sme-to-scf.mlir | 10 +-
.../test/Dialect/ArmSME/arith-ops-to-sme.mlir | 6 +-
.../Dialect/ArmSME/arm-sme-to-llvm-casts.mlir | 51 ---
mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 252 ++++++-----
mlir/test/Dialect/ArmSME/canonicalize.mlir | 40 +-
mlir/test/Dialect/ArmSME/cse.mlir | 36 +-
mlir/test/Dialect/ArmSME/invalid.mlir | 62 +--
mlir/test/Dialect/ArmSME/roundtrip.mlir | 193 ++-------
mlir/test/Dialect/ArmSME/tile-allocation.mlir | 376 ++++++++--------
mlir/test/Dialect/ArmSME/tile-zero-masks.mlir | 102 +----
.../Dialect/ArmSME/vector-ops-to-llvm.mlir | 410 +++++++++---------
.../Dialect/ArmSME/vector-ops-to-sme.mlir | 6 +-
.../Dialect/Linalg/CPU/ArmSME/fill-2d.mlir | 4 +-
.../Linalg/CPU/ArmSME/matmul-transpose-a.mlir | 4 +-
.../Dialect/Linalg/CPU/ArmSME/matmul.mlir | 4 +-
.../Vector/CPU/ArmSME/test-load-vertical.mlir | 4 +-
.../CPU/ArmSME/test-outerproduct-f32.mlir | 4 +-
.../CPU/ArmSME/test-outerproduct-f64.mlir | 4 +-
.../CPU/ArmSME/test-transfer-read-2d.mlir | 4 +-
.../CPU/ArmSME/test-transfer-write-2d.mlir | 4 +-
.../Vector/CPU/ArmSME/test-transpose.mlir | 4 +-
.../Dialect/Vector/CPU/ArmSME/tile_fill.mlir | 4 +-
.../Vector/CPU/ArmSME/vector-load-store.mlir | 4 +-
.../Dialect/Vector/CPU/ArmSME/vector-ops.mlir | 5 +-
mlir/test/Target/LLVMIR/arm-sme-invalid.mlir | 15 +-
mlir/test/Target/LLVMIR/arm-sme.mlir | 331 +++++++-------
mlir/tools/mlir-query/mlir-query.cpp | 6 +-
41 files changed, 1271 insertions(+), 1492 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
delete mode 100644 mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index fe1f9062a37ef51..1da8e488a4c4647 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -14,6 +14,8 @@
#define MLIR_DIALECT_ARMSME_IR_ARMSME_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -22,7 +24,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
new file mode 100644
index 000000000000000..430f3571001c8f4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
@@ -0,0 +1,16 @@
+//===- ArmSMEDialect.h - Arm SME Dialect Enums ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_ENUMS_H
+#define MLIR_DIALECT_ARMSME_ENUMS_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#endif
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index b75918ebf2f6d9c..2a0167afa8bae9e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -54,7 +54,10 @@ def MOPVector : ScalableVectorOfRankAndLengthAndType<[1], [16, 8, 4, 2],
}];
}
-class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
+class ArmSME_IntrOp<string mnemonic,
+ list<int> immArgPositions = [],
+ list<string> immArgAttrNames = [],
+ list<int> overloadedOperands = [],
list<Trait> traits = [], int numResults = 0,
list<int> overloadedResults = []>
: LLVM_IntrOpBase<
@@ -64,16 +67,26 @@ class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
/*list<int> overloadedResults=*/overloadedResults,
/*list<int> overloadedOperands=*/overloadedOperands,
/*list<Trait> traits=*/traits,
- /*int numResults=*/numResults>;
+ /*int numResults=*/numResults,
+ /*bit requiresAccessGroup=*/0,
+ /*bit requiresAliasAnalysis=*/0,
+ /*bit requiresFastmath=*/0,
+ /*list<int> immArgPositions=*/immArgPositions,
+ /*list<string> immArgAttrNames=*/immArgAttrNames>;
// Zero
-def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
- Arguments<(ins Arg<I32, "Tile mask">:$tile_mask)>;
+def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero",
+ /*immArgPositions=*/[0],
+ /*immArgAttrNames=*/["tile_mask"]>,
+ Arguments<(ins Arg<I32Attr, "Tile mask">:$tile_mask)>;
// MOP's
class ArmSME_IntrMopOverloadedOp<string mnemonic>
- : ArmSME_IntrOp<mnemonic, [4]>,
- Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+ : ArmSME_IntrOp<mnemonic,
+ /*immArgPositions=*/[0],
+ /*immArgAttrNames=*/["tile_id"],
+ /*overloadedOperands=*/[4]>,
+ Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
Arg<MOPPredicate, "LHS predicate">:$lhs_predicate,
Arg<MOPPredicate, "RHS predicate">:$rhs_predicate,
Arg<MOPVector, "LHS vector operand">:$lhs_vector,
@@ -92,12 +105,17 @@ def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
+class ArmSME_IntrLoadStoreOp<string mnemonic>
+ : ArmSME_IntrOp<mnemonic,
+ /*immArgPositions=*/[2],
+ /*immArgAttrNames=*/["tile_id"]>;
+
// Loads
class ArmSME_IntrLoadOp<string mnemonic>
- : ArmSME_IntrOp<mnemonic>,
+ : ArmSME_IntrLoadStoreOp<mnemonic>,
Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
Arg<LLVM_AnyPointer, "Load address">:$load_address,
- Arg<I32, "Virtual tile ID">:$tile_id,
+ Arg<I32Attr, "Virtual tile ID">:$tile_id,
Arg<I32, "Tile slice">:$tile_slice_index)>;
def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
@@ -113,10 +131,10 @@ def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
// Stores
class ArmSME_IntrStoreOp<string mnemonic>
- : ArmSME_IntrOp<mnemonic>,
+ : ArmSME_IntrLoadStoreOp<mnemonic>,
Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address,
- Arg<I32, "Virtual tile ID">:$tile_id,
+ Arg<I32Attr, "Virtual tile ID">:$tile_id,
Arg<I32, "Tile slice">:$tile_slice_index)>;
def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
@@ -138,22 +156,28 @@ def LLVM_aarch64_sme_str
// Vector to tile slice
class LLVM_aarch64_sme_write<string direction>
- : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
+ : ArmSME_IntrOp<"write." # direction,
+ /*immArgPositions=*/[0],
+ /*immArgAttrNames=*/["tile_id"],
+ /*overloadedOperands=*/[3],
[AllShapesMatch<["predicate", "vector"]>]>,
- Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+ Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
Arg<I32, "Tile slice">:$tile_slice_index,
Arg<SVEPredicate, "Vector predicate">:$predicate,
Arg<SVEVector, "Vector operand">:$vector)>;
// Tile slice to vector
class LLVM_aarch64_sme_read<string direction>
- : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+ : ArmSME_IntrOp<"read." # direction,
+ /*immArgPositions=*/[2],
+ /*immArgAttrNames=*/["tile_id"],
+ /*overloadedOperands=*/[],
[AllShapesMatch<["vector", "predicate", "res"]>,
AllElementTypesMatch<["vector", "res"]>],
/*numResults=*/1, /*overloadedResults=*/[0]>,
Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
Arg<SVEPredicate, "Vector predicate">:$predicate,
- Arg<I32, "Virtual tile ID">:$tile_id,
+ Arg<I32Attr, "Virtual tile ID">:$tile_id,
Arg<I32, "Tile slice">:$tile_slice_index)>;
def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index ba33a2826e6ca4b..abcc9b649c4a530 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -21,6 +21,99 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
+//===----------------------------------------------------------------------===//
+// ArmSME op interfaces
+//===----------------------------------------------------------------------===//
+
+def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
+ [
+ I32EnumAttrCase<"ZAB", 0, "za.b">,
+ I32EnumAttrCase<"ZAH", 1, "za.h">,
+ I32EnumAttrCase<"ZAS", 2, "za.s">,
+ I32EnumAttrCase<"ZAD", 3, "za.d">,
+ I32EnumAttrCase<"ZAQ", 4, "za.q">,
+ ]>{
+ let cppNamespace = "mlir::arm_sme";
+ let genSpecializedAttr = 0;
+}
+
+def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
+ let description = [{
+ An interface for operations that use or allocate Arm SME tiles. These
+ operations need to be assigned a tile ID an i32 attribute, which specifies
+ which virtual tile within the ZA storage to use. The number of tiles
+ available depends on the type of the tile. This is summarized below:
+
+ | Tile Vector Types | Possible Tile IDs |
+ |-------------------------------------------------------------------------|---------------------|
+ | `vector<[16]x[16]xi8>` | 0 |
+ | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` | 0 and 1 |
+ | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` | 0 to 3 (inclusive) |
+ | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` | 0 to 7 (inclusive) |
+ | `vector<[1]x[1]xi128>` | 0 to 15 (inclusive) |
+
+ Operations that allocate a new tiles (such as arm_sme.get_tile), are used as
+ the roots for tile allocation, with all operations that (transitively)
+ depend on a root being assigned the same tile ID.
+ }];
+ let methods = [
+ InterfaceMethod<
+ "Sets the tile ID for this operation.",
+ /*returnType=*/"void",
+ /*methodName=*/"setTileId",
+ /*arguments=*/(ins "mlir::IntegerAttr":$tileId),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/ [{
+ if (!tileId)
+ return;
+ ::mlir::Operation* op = this->getOperation();
+ op->setAttr("tile_id", tileId);
+ }]
+ >,
+ InterfaceMethod<
+ "Returns the (possibly null) tile ID assigned to this operation.",
+ /*returnType=*/"mlir::IntegerAttr",
+ /*methodName=*/"getTileId",
+ /*arguments=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/ [{
+ ::mlir::Operation* op = this->getOperation();
+ return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
+ }]
+ >,
+ InterfaceMethod<
+ "The type of tile this operation allocates (or none)",
+ /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
+ /*methodName=*/"getAllocatedTileType",
+ /*arguments=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/ [{
+ // Do not allocate a new tile.
+ return std::nullopt;
+ }]
+ >
+ ];
+
+ let extraSharedClassDeclaration = [{
+ // A helper to create a new operation and propagate this operations tile ID.
+ template<typename T, typename... Args>
+ T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
+ auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
+ if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
+ tileOp.setTileId($_op.getTileId());
+ return op;
+ }
+
+ // A helper to replace this operation and forward any tile ID.
+ template<typename T, typename... Args>
+ T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
+ auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
+ rewriter.replaceOp($_op, newOp);
+ return newOp;
+ }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// ArmSME type definitions
//===----------------------------------------------------------------------===//
@@ -44,7 +137,8 @@ def nxnxv2f64 : SMETileType<F64, [2, 2 ], "vector<[2]x[2]xf64>">;
def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64],
- "a vector type that fits into a SME tile">
+ "a vector type that fits into a SME tile",
+ "VectorType">
{
let description = [{
Possible vector types:
@@ -66,40 +160,6 @@ def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
}];
}
-def TileID : AnyTypeOf<[I8, I16, I32, I64, I128],
- "an identifier of a virtual tile (of a size) within the ZA storage">
-{
- let description = [{
- The tile ID is an 8, 16, 32, 64, or 128-bit signless integer. The value of
- the integer indicates the tile to use, and the bit size indicates the size
- of tile. The number of tiles available and the element types of those depend
- on the size. This is summarised below:
-
- | Tile ID Type | Possible Tile IDs | Tile Vector Types |
- |--------------|---------------------|-------------------------------------------------------------------------|
- | `i8` | 0 | `vector<[16]x[16]xi8>` |
- | `i16` | 0 and 1 | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` |
- | `i32` | 0 to 3 (inclusive) | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` |
- | `i64` | 0 to 7 (inclusive) | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` |
- | `i128` | 0 to 15 (inclusive) | `vector<[1]x[1]xi128>` |
- }];
-}
-
-// A type constraint that verifies the bitwidth of the scalar integer returned
-// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
-def TileElementWidthMatchesTileID : TypesMatchWith<
- "`tile_id` has the same number of bits as elements in `vector`",
- "vector", "tile_id",
- "IntegerType::get("
- "$_self.getContext(),"
- "::llvm::isa<IntegerType>(::llvm::cast<VectorType>($_self).getElementType())"
- "? ::llvm::cast<IntegerType>("
- "::llvm::cast<VectorType>($_self).getElementType())"
- ".getWidth()"
- ": ::llvm::cast<FloatType>("
- "::llvm::cast<VectorType>($_self).getElementType())"
- ".getWidth())">;
-
class HasMatchingMaskTypeConstraint<string vector, string mask> :
OptionalTypesMatchWith<
mask # " has i1 element type and same shape as " # vector,
@@ -162,125 +222,67 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
Op<ArmSME_Dialect, mnemonic, traits> {}
-def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthMatchesTileID]> {
- let summary = "Cast from tile id to 2-d scalable vector type";
+def GetTile : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
+ let summary = "Returns a SME virtual tile";
let description = [{
- A `cast_tile_to_vector` operation does a cast from a tile id to a 2-d
- scalable vector type, which represents an SME "virtual tile". This would
- normally be used when lowering operations that return "virtual tile" vector
- types to model the output. This is required to preserve dataflow as SME
- intrinsics have no return values.
+ Allocates a new SME "virtual tile" within a function. The contents of the
+ tile returned from this operation undefined.
- Example:
+ Example 1:
- Input:
```mlir
- %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
- vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+ // Allocate an 8-bit element "virtual tile"
+ %za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
```
- After lowering `vector.load`:
+ Example 2:
+
```mlir
- %tile_id = arm_sme.get_tile_id : i32
- scf.for %vnum = %c0 to %num_vectors step %c1 {
- // ...
- "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
- }
- %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
- vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+ // Allocate two 16-bit element "virtual tiles"
+ %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+ %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
```
- In the example above, the `vector.load` can't be replaced with an SME
- intrinsic that has no outputs since it is used by the `vector.store`.
- However, by inserting a `cast_tile_to_vector` op after the load intrinsics
- the `vector.load` can be replaced. This enables "local" rewrites on
- individual vector ops, rather than "global" rewrites that would have to
- look at the vector op uses and also lower them.
-
- Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
- the cast away if it comes from a `arm_sme.cast_vector_to_tile`.
- }];
- let arguments = (ins TileID:$tile_id);
- let results = (outs SMETile:$vector);
- let assemblyFormat =
- "$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
- let hasCanonicalizeMethod = 1;
-}
-
-def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthMatchesTileID]> {
- let summary = "Cast from 2-d scalable vector type to tile id";
- let description = [{
- A `cast_vector_to_tile` operation does a cast from a 2-d scalable vector
- type, which represents an SME "virtual tile", to a tile id. This is
- required to preserve dataflow as the SME intrinsics have no return values.
-
- Example:
-
- Input:
+ Example 3:
```mlir
- %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
- vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+ // Allocate an 128-bit element "virtual tile"
+ %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
```
+ }];
- After lowering `vector.store`:
- ```mlir
- %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
- scf.for %vnum = %c0 to %num_vectors step %c1 {
- // ...
- %tile_id = arm_sme.cast_vector_to_tile %tile : (vector<[4]x[4]xi32>) -> i32
- "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+ let results = (outs SMETile:$tile);
+ let assemblyFormat = "attr-dict `:` type($tile)";
+
+ let extraClassDeclaration = [{
+ VectorType getTileType() {
+ return ::llvm::cast<VectorType>(getTile().getType());
}
- ```
- Canonicalization will look through `arm_sme.cast_vector_to_tile` and fold
- the cast away if it comes from a `arm_sme.cast_tile_to_vector`.
+ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+ return arm_sme::getSMETileType(getTileType());
+ }
}];
- let arguments = (ins SMETile:$vector);
- let results = (outs TileID:$tile_id);
- let assemblyFormat =
- "$vector attr-dict `:` type($vector) `to` type($tile_id)";
- let hasCanonicalizeMethod = 1;
}
-def GetTileID : ArmSME_Op<"get_tile_id"> {
- let summary = "Returns an SME \"virtual tile\" id";
+def MaterializeSSATile : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
+ let summary = "SME tile placeholder";
let description = [{
- A `get_tile_id` operation returns a scalar integer representing an SME
- "virtual tile" id. The bitwidth of the scalar indicates the element
- bitwidth of the "virtual tile".
-
- The scope of a tile id is a function and cannot be passed or returned from
- functions.
+ A placeholder to preserve dataflow while lowering to SME intrinsics (which
+ do not take or return tile values). This operation is intended to be DCE'd
+ once all ArmSME operations have been lowered.
- Example:
- ```mlir
- // Allocate and return an 8-bit element "virtual tile" id
- %za0_b = arm_sme.get_tile_id : i8
- ```
-
- Example:
- ```
- // Allocate and return two 16-bit element "virtual tile" ids
- %za0_h = arm_sme.get_tile_id : i16
- %za1_h = arm_sme.get_tile_id : i16
- ```
-
- Example:
- ```
- // Allocate and return an 128-bit element "virtual tile" id
- %za0_q = arm_sme.get_tile_id : i128
- ```
+ This operation is not intended to be used outside of the ArmSME -> LLVM
+ conversion.
}];
-
- let results = (outs TileID:$tile_id);
- let assemblyFormat = "attr-dict `:` type($tile_id)";
+ let results = (outs SMETile:$tile);
+ let assemblyFormat = "attr-dict `:` type($tile)";
}
//
// Tile reset.
//
-def ZeroOp : ArmSME_Op<"zero", [Pure]> {
+def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
let summary = "Initialize the two-dimensional ZA array with 0s";
let results = (outs SMETile:$res);
let description = [{
@@ -303,11 +305,15 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getRes().getType());
}
+ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+ return arm_sme::getSMETileType(getVectorType());
+ }
}];
let assemblyFormat = "attr-dict `:` type($res)";
}
def TileLoadOp : ArmSME_Op<"tile_load", [
+ ArmSMETileOpInterface,
AttrSizedOperandSegments,
OptionalTypesMatchWith<
"padding type matches element type of result",
@@ -375,6 +381,9 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
+ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+ return arm_sme::getSMETileType(getVectorType());
+ }
}];
let builders = [
@@ -394,6 +403,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
}
def TileStoreOp : ArmSME_Op<"tile_store", [
+ ArmSMETileOpInterface,
AttrSizedOperandSegments,
HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
]> {
@@ -457,6 +467,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
}
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
+ ArmSMETileOpInterface,
AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask">
]> {
let summary = "Tile slice load and update operation";
@@ -515,6 +526,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
}
def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
+ ArmSMETileOpInterface,
TileSliceMaskConstraint<"tile", "mask">
]> {
let summary = "Tile slice store operation";
@@ -570,6 +582,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
}
def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
+ ArmSMETileOpInterface,
AllTypesMatch<["tile", "result"]>,
TypesMatchWith<
"type of 'vector' matches type of 'tile' slice",
@@ -617,7 +630,8 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
}];
}
-def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
+def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
+ ArmSMETileOpInterface,
TypesMatchWith<
"type of 'result' matches type of 'tile' slice",
"tile", "result",
@@ -670,7 +684,8 @@ class OuterProductResultTileTypeConstraint<string operand> :
"}()">;
def OuterProductOp :
- ArmSME_Op<"outerproduct", [Pure,
+ ArmSME_Op<"outerproduct", [
+ ArmSMETileOpInterface,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
@@ -715,7 +730,7 @@ def OuterProductOp :
```
}];
- let arguments = (ins
+let arguments = (ins
SVEVector:$lhs, SVEVector:$rhs,
Optional<SVEPredicate>:$lhsMask,
Optional<SVEPredicate>:$rhsMask,
@@ -736,6 +751,12 @@ def OuterProductOp :
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
+ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+ // The outerproduct op allocates a new tile if no accumulator is passed.
+ if (!getAcc())
+ return arm_sme::getSMETileType(getResultType());
+ return std::nullopt;
+ }
}];
}
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index 9319a8042c93f4a..9801d8b099e3f0a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -15,6 +15,12 @@ set(LLVM_TARGET_DEFINITIONS ArmSMEOps.td)
mlir_tablegen(ArmSMEOpsConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
+# Generate op interface declarations and definitions
+set(LLVM_TARGET_DEFINITIONS ArmSMEOps.td)
+mlir_tablegen(ArmSMEOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(ArmSMEOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRArmSMEOpInterfaces)
+
# Generate declarations and definitions of ArmSME intrinsic Ops
set(LLVM_TARGET_DEFINITIONS ArmSMEIntrinsicOps.td)
mlir_tablegen(ArmSMEIntrinsicOps.h.inc -gen-op-decls)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 0941592497beaae..954371cb9d5a0b1 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -15,10 +15,11 @@
#ifndef MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
-#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include <optional>
-namespace mlir {
-namespace arm_sme {
+namespace mlir::arm_sme {
constexpr unsigned MinStreamingVectorLengthInBits = 128;
@@ -34,13 +35,8 @@ bool isValidSMETileElementType(Type type);
/// otherwise.
bool isValidSMETileVectorType(VectorType vType);
-/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
-/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
-/// integer, to an i32 that can be passed as the `tile` parameter to the SME
-/// intrinsics. Or returns `tile` if already i32.
-Value castTileIDToI32(Value tile, Location loc, RewriterBase &rewriter);
+std::optional<ArmSMETileType> getSMETileType(VectorType);
-} // namespace arm_sme
-} // namespace mlir
+} // namespace mlir::arm_sme
#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index e409dc57fb020e2..4f4b090dd10e3c0 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -32,6 +32,26 @@ using namespace mlir;
namespace {
+IntegerAttr getTileIdOrError(ArmSMETileOpInterface op) {
+ auto tileId = op.getTileId();
+ if (!tileId)
+ op.emitOpError(
+ "expected tile ID to be allocated before conversion to LLVM");
+ return tileId;
+}
+
+struct GetTileConversion : public ConvertOpToLLVMPattern<arm_sme::GetTile> {
+ using ConvertOpToLLVMPattern<arm_sme::GetTile>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(arm_sme::GetTile getTile, OpAdaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATile>(
+ getTile, getTile.getTileType());
+ return success();
+ }
+};
+
/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
@@ -61,9 +81,9 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
unsigned tileElementWidth =
zero.getVectorType().getElementType().getIntOrFloatBitWidth();
- // Get Tile ID for the `zero` intrinsic.
- auto tileId = rewriter.create<arm_sme::GetTileID>(
- loc, rewriter.getIntegerType(tileElementWidth));
+ auto tileId = getTileIdOrError(zero);
+ if (!tileId)
+ return failure();
// Get the base mask for tile based on the element size.
// The base mask is just the mask to zero the first tile (of a size).
@@ -93,9 +113,6 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
llvm_unreachable("bad element size");
}
}();
- auto maskType = rewriter.getI32Type();
- auto baseMask = rewriter.create<arith::ConstantOp>(
- loc, maskType, rewriter.getIntegerAttr(maskType, baseMaskForSize));
// The actual mask is just the base mask shifted by the tile ID.
// This will be folded to a constant after tile allocation.
@@ -118,13 +135,13 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
// * Mask -> 10001000 = (00010001 << 3)
//
// This holds for all tile sizes.
- auto tileMask = rewriter.create<arith::ShLIOp>(
- loc, baseMask, castTileIDToI32(tileId, loc, rewriter));
- rewriter.create<arm_sme::aarch64_sme_zero>(loc, tileMask);
+ int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
+ rewriter.create<arm_sme::aarch64_sme_zero>(
+ loc, rewriter.getI32IntegerAttr(zeroMask));
- // Create `CastTileToVectorOp` to use as the output.
- rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(zero, zero.getType(),
- tileId);
+ // Create a placeholder op to preserve dataflow.
+ rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATile>(
+ zero, zero.getVectorType());
return success();
}
@@ -141,15 +158,9 @@ struct LoadTileSliceConversion
arm_sme::LoadTileSliceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = loadTileSliceOp.getLoc();
- auto tileType = loadTileSliceOp.getVectorType();
- auto tileElementType = tileType.getElementType();
- unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
-
- // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the tile being
- // loaded to.
- auto tile = rewriter.create<arm_sme::CastVectorToTile>(
- loc, rewriter.getIntegerType(tileElementWidth),
- loadTileSliceOp.getTile());
+ auto tileId = getTileIdOrError(loadTileSliceOp);
+ if (!tileId)
+ return failure();
Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
adaptor.getBase(),
@@ -164,7 +175,9 @@ struct LoadTileSliceConversion
// Create all active predicate mask.
auto maskOp = loadTileSliceOp.getMask();
- auto tileI32 = castTileIDToI32(tile, loc, rewriter);
+ auto tileType = loadTileSliceOp.getVectorType();
+ auto tileElementType = tileType.getElementType();
+ unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
// Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
@@ -174,23 +187,23 @@ struct LoadTileSliceConversion
llvm_unreachable("unexpected element type!");
case 8:
rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
- tileI32, tileSliceI32);
+ tileId, tileSliceI32);
break;
case 16:
rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
- tileI32, tileSliceI32);
+ tileId, tileSliceI32);
break;
case 32:
rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
- tileI32, tileSliceI32);
+ tileId, tileSliceI32);
break;
case 64:
rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
- tileI32, tileSliceI32);
+ tileId, tileSliceI32);
break;
case 128:
rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
- tileI32, tileSliceI32);
+ tileId, tileSliceI32);
break;
}
} else {
@@ -199,31 +212,30 @@ struct LoadTileSliceConversion
llvm_unreachable("unexpected element type!");
case 8:
rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
- tileI32, tileSliceI32);
+ tileId, tileSliceI32);
break;
case 16:
rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
- tileI32, tileSliceI32);
+ tileId, tileSliceI32);
break;
case 32:
rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
- tileI32, tileSliceI32);
+ tileId, tileSliceI32);
break;
case 64:
rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
- tileI32, tileSliceI32);
+ tileId, tileSliceI32);
break;
case 128:
rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
- tileI32, tileSliceI32);
+ tileId, tileSliceI32);
break;
}
}
// The load intrinsics have no result, replace 'arm_sme.tile_load' with
- // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
- rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(loadTileSliceOp,
- tileType, tile);
+ // the input tile to preserve dataflow.
+ rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
return success();
}
@@ -244,11 +256,9 @@ struct StoreTileSliceConversion
auto tileElementType = tileType.getElementType();
unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
- // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the vector
- // being stored.
- auto tile = rewriter.create<arm_sme::CastVectorToTile>(
- loc, rewriter.getIntegerType(tileElementWidth),
- storeTileSliceOp.getTile());
+ auto tileId = getTileIdOrError(storeTileSliceOp);
+ if (!tileId)
+ return failure();
// Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
Value ptr = this->getStridedElementPtr(
@@ -263,7 +273,6 @@ struct StoreTileSliceConversion
auto maskOp = storeTileSliceOp.getMask();
- Value tileI32 = castTileIDToI32(tile, loc, rewriter);
arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
if (layout == arm_sme::TileSliceLayout::Horizontal) {
@@ -272,23 +281,23 @@ struct StoreTileSliceConversion
llvm_unreachable("unexpected element type!");
case 8:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
- storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+ storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
break;
case 16:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
- storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+ storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
break;
case 32:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
- storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+ storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
break;
case 64:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
- storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+ storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
break;
case 128:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
- storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+ storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
break;
}
} else {
@@ -297,23 +306,23 @@ struct StoreTileSliceConversion
llvm_unreachable("unexpected element type!");
case 8:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
- storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+ storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
break;
case 16:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
- storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+ storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
break;
case 32:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
- storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+ storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
break;
case 64:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
- storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+ storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
break;
case 128:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
- storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+ storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
break;
}
}
@@ -334,14 +343,10 @@ struct MoveVectorToTileSliceConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = moveVectorToTileSliceOp.getLoc();
auto tileType = moveVectorToTileSliceOp.getTileType();
- auto tileElementType = tileType.getElementType();
- unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
- // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the tile being
- // loaded to.
- auto tile = rewriter.create<arm_sme::CastVectorToTile>(
- loc, rewriter.getIntegerType(tileElementWidth),
- moveVectorToTileSliceOp.getTile());
+ auto tileId = getTileIdOrError(moveVectorToTileSliceOp);
+ if (!tileId)
+ return failure();
auto tileSlice = moveVectorToTileSliceOp.getTileSliceIndex();
@@ -357,26 +362,24 @@ struct MoveVectorToTileSliceConversion
/*scalableDims=*/{true});
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
- auto tileI32 = castTileIDToI32(tile, loc, rewriter);
-
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
switch (moveVectorToTileSliceOp.getLayout()) {
case arm_sme::TileSliceLayout::Horizontal:
rewriter.create<arm_sme::aarch64_sme_write_horiz>(
- loc, tileI32, tileSliceI32, allActiveMask,
+ loc, tileId, tileSliceI32, allActiveMask,
moveVectorToTileSliceOp.getVector());
break;
case arm_sme::TileSliceLayout::Vertical:
rewriter.create<arm_sme::aarch64_sme_write_vert>(
- loc, tileI32, tileSliceI32, allActiveMask,
+ loc, tileId, tileSliceI32, allActiveMask,
moveVectorToTileSliceOp.getVector());
break;
}
// Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
- // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
- rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
- moveVectorToTileSliceOp, tileType, tile);
+ // the input tile to preserve dataflow.
+ rewriter.replaceOp(moveVectorToTileSliceOp,
+ moveVectorToTileSliceOp.getTile());
return success();
}
@@ -394,12 +397,11 @@ struct MoveTileSliceToVectorConversion
ConversionPatternRewriter &rewriter) const override {
auto loc = moveTileSliceToVector.getLoc();
auto sliceType = moveTileSliceToVector.getSliceType();
- auto tile = moveTileSliceToVector.getTile();
auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();
- // Cast tile to i32 tile ID.
- auto tileId = rewriter.create<arm_sme::CastVectorToTile>(loc, tile);
- Value tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
+ auto tileId = getTileIdOrError(moveTileSliceToVector);
+ if (!tileId)
+ return failure();
// Create an 'all true' predicate for the tile slice.
auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
@@ -419,12 +421,12 @@ struct MoveTileSliceToVectorConversion
case arm_sme::TileSliceLayout::Horizontal:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
- tileIdI32, sliceIndexI32);
+ tileId, sliceIndexI32);
break;
case arm_sme::TileSliceLayout::Vertical:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
- tileIdI32, sliceIndexI32);
+ tileId, sliceIndexI32);
break;
}
@@ -454,6 +456,10 @@ struct OuterProductOpConversion
matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
arm_sme::OuterProductOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ auto tileId = getTileIdOrError(outerProductOp);
+ if (!tileId)
+ return failure();
+
auto isSupportedType = [](VectorType vectorType) {
// TODO: the FP outer product instruction variants are predicated on
// different features [1]:
@@ -498,13 +504,8 @@ struct OuterProductOpConversion
Value acc = outerProductOp.getAcc();
if (!acc)
// Initalize accumulator with zero.
- acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
-
- unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
- auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
- loc, rewriter.getIntegerType(elementWidth), acc);
-
- auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
+ acc = outerProductOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+ rewriter, loc, resultVectorType);
Value lhsMask = outerProductOp.getLhsMask();
Value rhsMask = outerProductOp.getRhsMask();
@@ -519,13 +520,13 @@ struct OuterProductOpConversion
}
// Create 'arm_sme.intr.mopa' outer product intrinsic.
- rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileI32, lhsMask, rhsMask,
+ rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
outerProductOp.getLhs(),
outerProductOp.getRhs());
- // Create `CastTileToVectorOp` to use as the output.
- rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
- outerProductOp, resultVectorType, tileId);
+ // The outerproduct intrinsics have no result, replace
+ // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
+ rewriter.replaceOp(outerProductOp, acc);
return success();
}
@@ -557,21 +558,20 @@ struct ConvertArmSMEToLLVMPass
void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
target.addIllegalDialect<arm_sme::ArmSMEDialect>();
target.addLegalOp<
- arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
- arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
- arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
- arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
- arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
- arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
- arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
- arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
- arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
- arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
- arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
- arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
- arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
- arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
- arm_sme::aarch64_sme_mopa>();
+ arm_sme::MaterializeSSATile, arm_sme::aarch64_sme_zero,
+ arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
+ arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
+ arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
+ arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
+ arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
+ arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
+ arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
+ arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
+ arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
+ arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
+ arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
+ arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
+ arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
}
@@ -580,7 +580,8 @@ void mlir::populateArmSMEToLLVMConversionPatterns(
ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
- OuterProductOpConversion, ZeroOpConversion>(converter);
+ OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
+ converter);
}
std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index a4dfd4b7edc77c5..541d711fbd95f29 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -87,16 +87,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
auto loc = tileLoadOp.getLoc();
auto tileType = tileLoadOp.getVectorType();
auto tileElementType = tileType.getElementType();
- unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
- // Create 'arm_sme.get_tile' op.
- auto tileId = rewriter.create<arm_sme::GetTileID>(
- loc, rewriter.getIntegerType(tileElementWidth));
-
- // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
- // use as input tile to 'arm_sme.load_tile_slice' ops.
- auto tile =
- rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+ // Allocate a new SME tile.
+ auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTile>(
+ rewriter, loc, tileType);
// Create a loop that loads each ZA tile slice from memory.
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -128,8 +122,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
numTileSlices, memrefIndices, loc, rewriter);
- rewriter.create<arm_sme::LoadTileSliceOp>(
- loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
+ tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+ rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
rewriter.setInsertionPointAfter(forOp);
@@ -209,7 +203,8 @@ struct TileLoadOpWithMaskAndPadZeroConversion
// Initialize tile with zero to satisfy padding. Inactive cols will be
// zeroed anyway since the loads use zeroing predication. For inactive rows
// however, no load will occur so these need to be zeroed.
- auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+ auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+ rewriter, loc, tileType);
// Create a loop to load the active tile slices from memory.
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -226,9 +221,9 @@ struct TileLoadOpWithMaskAndPadZeroConversion
getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
upperBound, memrefIndices, loc, rewriter);
- rewriter.create<arm_sme::LoadTileSliceOp>(
- loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
- tileSliceIndex, tileLoadOp.getLayout());
+ tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+ rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp, tile,
+ memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
rewriter.setInsertionPointAfter(forOp);
@@ -276,7 +271,6 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto loc = tileLoadOp.getLoc();
auto tileType = tileLoadOp.getVectorType();
auto tileElementType = tileType.getElementType();
- unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
auto maskOp = tileLoadOp.getMask();
if (!maskOp)
@@ -304,14 +298,9 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
loc, rewriter.getI32Type(), numCols);
- // Create 'arm_sme.get_tile' op.
- auto tileId = rewriter.create<arm_sme::GetTileID>(
- loc, rewriter.getIntegerType(tileElementWidth));
-
- // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
- // use as input tile to 'arm_sme.load_tile_slice' ops.
- auto tile =
- rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+ // Allocate a new SME tile.
+ auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTile>(
+ rewriter, loc, tileType);
// Create a loop that loads each ZA tile slice from memory.
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -356,8 +345,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
/*passthru=*/pad1DOp);
// Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
- rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
+ tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
+ rewriter, loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
tileLoadOp.getLayout());
rewriter.setInsertionPointAfter(forOp);
@@ -450,8 +439,9 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
getMemrefIndices(tileStoreOp.getIndices(),
tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
upperBound, memrefIndices, loc, rewriter);
- rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
- tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
+
+ tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
+ rewriter, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
return success();
@@ -506,6 +496,7 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
rewriter.setInsertionPointToStart(forOp.getBody());
// Extract the current row from the tile.
Value rowIndex = forOp.getInductionVar();
+ // FIXME: Forward tile IDs.
auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
loc, printOp.getSource(), rowIndex);
// Print the row with a 1D vector.print.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index f1e92d1ac9708b7..109b04ce34a88b7 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -44,20 +44,6 @@ static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
return forOp;
}
-/// Returns a tile of the given vector type.
-static arm_sme::CastTileToVector
-getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
- VectorType type) {
- unsigned tileElementWidth = type.getElementType().getIntOrFloatBitWidth();
-
- // Create 'arm_sme.get_tile' op.
- auto tileId = rewriter.create<arm_sme::GetTileID>(
- loc, rewriter.getIntegerType(tileElementWidth));
-
- // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type.
- return rewriter.create<arm_sme::CastTileToVector>(loc, type, tileId);
-}
-
namespace {
/// Conversion pattern for vector.transfer_read.
@@ -267,8 +253,7 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
tileSliceType, denseAttr.getSplatValue<Attribute>());
auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
- arm_sme::CastTileToVector tile =
- getSMETileAndCastToVector(rewriter, loc, tileType);
+ auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
auto tileSliceIndex = forOp.getInductionVar();
@@ -330,8 +315,7 @@ struct BroadcastOpToArmSMELowering
else
return failure();
- arm_sme::CastTileToVector tile =
- getSMETileAndCastToVector(rewriter, loc, tileType);
+ auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
// Create a loop over ZA tile slices.
auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
@@ -387,8 +371,7 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
loc, tileSliceType, splatOp.getInput());
- arm_sme::CastTileToVector tile =
- getSMETileAndCastToVector(rewriter, loc, tileType);
+ auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
// Next, create a loop over ZA tile slices and "move" the generated 1-d
// vector to each slice.
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 9df15420b9c9b6f..29fa9085a0a963d 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -28,6 +28,8 @@ using namespace mlir::arm_sme;
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.cpp.inc"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.cpp.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMEOps.cpp.inc"
@@ -54,23 +56,3 @@ void ArmSMEDialect::initialize() {
#include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.cpp.inc"
>();
}
-
-// cast_vector_to_tile(cast_tile_to_vector(tile_id)) -> tile_id
-LogicalResult CastVectorToTile::canonicalize(CastVectorToTile op,
- PatternRewriter &rewriter) {
- if (auto castTileToVectorOp = op.getVector().getDefiningOp<CastTileToVector>()) {
- op.replaceAllUsesWith(castTileToVectorOp.getTileId());
- return success();
- }
- return failure();
-}
-
-// cast_tile_to_vector(cast_vector_to_tile(tile)) -> tile
-LogicalResult CastTileToVector::canonicalize(CastTileToVector op,
- PatternRewriter &rewriter) {
- if (auto castVectorToTileOp = op.getTileId().getDefiningOp<CastVectorToTile>()) {
- op.replaceAllUsesWith(castVectorToTileOp.getVector());
- return success();
- }
- return failure();
-}
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 3e448ec4fb1e04d..66062335fa842a4 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_dialect_library(MLIRArmSMEDialect
MLIRSCFDialect
MLIRSideEffectInterfaces
MLIRVectorDialect
+ MLIRArmSMEUtils
)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index e0462a6dc124131..0a65efeb266d15b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -41,10 +41,12 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "allocate-arm-sme-tiles"
@@ -107,73 +109,85 @@ enum class TileMask : unsigned {
};
/// Returns the set of masks relevant for the given type.
-static ArrayRef<TileMask> getMasks(Type type) {
- static const SmallVector<TileMask> ZA_B_MASKS = {TileMask::kZA0B};
- static const SmallVector<TileMask> ZA_H_MASKS = {TileMask::kZA0H,
- TileMask::kZA1H};
- static const SmallVector<TileMask> ZA_S_MASKS = {
- TileMask::kZA0S, TileMask::kZA1S, TileMask::kZA2S, TileMask::kZA3S};
- static const SmallVector<TileMask> ZA_D_MASKS = {
+static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
+ static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
+ static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
+ static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
+ TileMask::kZA2S, TileMask::kZA3S};
+ static constexpr std::array ZA_D_MASKS = {
TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
- static const SmallVector<TileMask> ZA_Q_MASKS = {
+ static constexpr std::array ZA_Q_MASKS = {
TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q,
TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q,
TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q,
TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
- switch (cast<IntegerType>(type).getWidth()) {
- default:
- llvm_unreachable("unexpected type!");
- case 8:
+ switch (type) {
+ case ArmSMETileType::ZAB:
return ZA_B_MASKS;
- case 16:
+ case ArmSMETileType::ZAH:
return ZA_H_MASKS;
- case 32:
+ case ArmSMETileType::ZAS:
return ZA_S_MASKS;
- case 64:
+ case ArmSMETileType::ZAD:
return ZA_D_MASKS;
- case 128:
+ case ArmSMETileType::ZAQ:
return ZA_Q_MASKS;
}
}
-/// Allocates a tile to 'tileID' or returns an error if there are no tiles left.
-static LogicalResult getTile(GetTileID tileIDOp, TileMask &tilesInUse,
- unsigned &tileID) {
- auto masks = getMasks(tileIDOp.getType());
- for (const auto &it : llvm::enumerate(masks)) {
- const auto tileMask = it.value();
+/// Allocates and returns a tile ID, or returns failure if there are no tiles left.
+static FailureOr<unsigned> getTile(ArmSMETileType tileType,
+ TileMask &tilesInUse) {
+ auto masks = getMasks(tileType);
+ for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
if ((tilesInUse & tileMask) == TileMask::kNone) {
tilesInUse |= tileMask;
- tileID = it.index();
- return success();
+ return tileId;
}
}
- return tileIDOp.emitError("ran out of SME virtual tiles!");
+ return failure();
}
-struct GetTileIDConversion : public OpRewritePattern<GetTileID> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(GetTileID tileIDOp,
+struct AssignTileIDsPattern
+ : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
+ using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+ LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
PatternRewriter &rewriter) const override {
- auto funcOp = tileIDOp->getParentOfType<func::FuncOp>();
+ if (tileOp.getTileId())
+ return failure();
+
+ std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
+ if (!tileType)
+ return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
+
+ auto func = tileOp->getParentOfType<FunctionOpInterface>();
TileMask tilesInUse;
- if (auto tilesInUseAttr =
- funcOp->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
+ if (auto tilesInUseAttr = func->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
tilesInUse = static_cast<TileMask>(tilesInUseAttr.getInt());
else
tilesInUse = TileMask::kNone;
- unsigned tileID;
- if (failed(getTile(tileIDOp, tilesInUse, tileID)))
- return failure();
+ auto tileId = getTile(*tileType, tilesInUse);
+ if (failed(tileId))
+ return tileOp.emitError("ran out of SME virtual tiles!");
- funcOp->setAttr(kTilesInUseAttr,
- rewriter.getI32IntegerAttr((unsigned)tilesInUse));
+ func->setAttr(kTilesInUseAttr,
+ rewriter.getI32IntegerAttr((unsigned)tilesInUse));
+
+ // Find all the ops that (transitively) depend on this tile.
+ SetVector<Operation *> dependantOps;
+ getForwardSlice(tileOp.getOperation(), &dependantOps);
+
+ // Set all operations to use the same tile ID.
+  // This is a naive tile allocation scheme, but it works for common cases.
+ auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
+ tileOp.setTileId(tileIDAttr);
+ for (auto *op : dependantOps) {
+ if (auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op))
+ tileOp.setTileId(tileIDAttr);
+ }
- auto tileType = tileIDOp.getType();
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- tileIDOp, tileType, rewriter.getIntegerAttr(tileType, tileID));
return success();
}
};
@@ -182,13 +196,14 @@ struct TileAllocationPass
: public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
- ConversionTarget target(getContext());
- patterns.add<GetTileIDConversion>(patterns.getContext());
- target.addLegalOp<arith::ConstantOp>();
- target.addIllegalOp<GetTileID>();
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ patterns.add<AssignTileIDsPattern>(patterns.getContext());
+ GreedyRewriteConfig config;
+ // This ensures tiles are allocated in program order.
+ config.useTopDownTraversal = true;
+ if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+ getOperation(), std::move(patterns), config))) {
signalPassFailure();
+ }
}
};
} // namespace
diff --git a/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
index da8517aaf80a9fa..ecf774a215d24f8 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
@@ -5,7 +5,5 @@ add_mlir_dialect_library(MLIRArmSMEUtils
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Utils
LINK_LIBS PUBLIC
- MLIRArmSMEDialect
- MLIRDialect
MLIRIR
)
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index f17077ff8565d59..c3cdf5703bcbc28 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -12,24 +12,20 @@
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+namespace mlir::arm_sme {
-using namespace mlir;
-using namespace mlir::arm_sme;
-
-unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
+unsigned getSMETileSliceMinNumElts(Type type) {
assert(isValidSMETileElementType(type) && "invalid tile type!");
return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
}
-bool mlir::arm_sme::isValidSMETileElementType(Type type) {
+bool isValidSMETileElementType(Type type) {
return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
type.isInteger(64) || type.isInteger(128) || type.isF16() ||
type.isBF16() || type.isF32() || type.isF64() || type.isF128();
}
-bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
+bool isValidSMETileVectorType(VectorType vType) {
if ((vType.getRank() != 2) || !vType.allDimsScalable())
return false;
@@ -37,22 +33,30 @@ bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
if (!isValidSMETileElementType(elemType))
return false;
- unsigned minNumElts = arm_sme::getSMETileSliceMinNumElts(elemType);
+ unsigned minNumElts = getSMETileSliceMinNumElts(elemType);
if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
return false;
return true;
}
-Value mlir::arm_sme::castTileIDToI32(Value tile, Location loc,
- RewriterBase &rewriter) {
- assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
- tile.getDefiningOp())) &&
- "expected ArmSME GetTileID or CastVectorToTile op!");
- unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
- if (tileElementWidth < 32)
- return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
- if (tileElementWidth > 32)
- return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
- return tile;
+std::optional<ArmSMETileType> getSMETileType(VectorType type) {
+ if (!isValidSMETileVectorType(type))
+ return {};
+ switch (type.getElementTypeBitWidth()) {
+ case 8:
+ return ArmSMETileType::ZAB;
+ case 16:
+ return ArmSMETileType::ZAH;
+ case 32:
+ return ArmSMETileType::ZAS;
+ case 64:
+ return ArmSMETileType::ZAD;
+ case 128:
+ return ArmSMETileType::ZAQ;
+ default:
+ llvm_unreachable("unknown SME tile type");
+ }
}
+
+} // namespace mlir::arm_sme
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 38a2332ffd5e0a3..fc28645a7acf7c0 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -6,8 +6,7 @@
// CHECK-LABEL: func.func @arm_sme_tile_load_hor(
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
@@ -16,7 +15,7 @@
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
// CHECK-NEXT: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -60,8 +59,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>,
// CHECK-SAME: %[[PAD:.*]]: i32) {
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
@@ -79,7 +77,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
-// CHECK: arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+// CHECK: arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
index 12f7e7333ebc973..b8db105f9c601b4 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
@@ -95,8 +95,7 @@ func.func @arith_constant_dense_2d_zero_f64() {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[GET_TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[GET_TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
@@ -115,8 +114,7 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[GET_TILE_ID:.*]] = arm_sme.get_tile_id : i64
-// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[GET_TILE_ID]] : i64 to vector<[2]x[2]xf64>
+// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
deleted file mode 100644
index 65996e81c42d909..000000000000000
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
+++ /dev/null
@@ -1,51 +0,0 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -split-input-file | FileCheck %s
-
-// This test verifies the temporary casts that are emitted when lowering to
-// intrinsics to preserve data flow are correct. Canonicalization will remove
-// these.
-
-// CHECK-LABEL: @arm_sme_zero
-// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK: arm_sme.intr.zero
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
-// CHECK: scf.for
-// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
-// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK: "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_zero(%dest : memref<?x?xi8>) {
- %c0 = arith.constant 0 : index
- %tile = arm_sme.zero : vector<[16]x[16]xi8>
- arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
- return
-}
-
-// -----
-
-// CHECK-LABEL: @arm_sme_tile_load
-// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
-// CHECK: scf.for
-// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
-// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK: "arm_sme.intr.ld1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-// CHECK: }
-// CHECK: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
-func.func @arm_sme_tile_load(%dest : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
- %c0 = arith.constant 0 : index
- %tile = arm_sme.tile_load %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
- return %tile : vector<[16]x[16]xi8>
-}
-
-// -----
-
-// CHECK-LABEL: @arm_sme_tile_store(
-// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>,
-// CHECK: scf.for
-// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
-// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK: "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_tile_store(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
- %c0 = arith.constant 0 : index
- arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
- return
-}
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index fa62332bc3f5b17..bd88da37bdf966d 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
// Test conversion of ArmSME ops to LLVM intrinsics.
@@ -9,23 +9,21 @@
// CHECK-LABEL: func.func @arm_sme_load_tile_slice_hor_i8(
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi8>,
// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>,
-// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>,
-// CHECK-SAME: %[[TILE_SLICE_INDEX:.*]]: index) {
+// CHECK-SAME: %[[TILE_SLICE_INDEX:.*]]: index)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
// CHECK: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[STRIDE:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[OFFSET:.*]] = llvm.mul %[[C0_I64]], %[[STRIDE]] : i64
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// CHECK: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
-// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK: "arm_sme.intr.ld1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK: "arm_sme.intr.ld1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_SLICE_INDEX_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
// CHECK: return
// CHECK: }
-func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
return
}
@@ -33,9 +31,10 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i16
-// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
return
}
@@ -43,9 +42,10 @@ func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i32
-// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
return
}
@@ -53,9 +53,10 @@ func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i64
-// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
return
}
@@ -63,9 +64,10 @@ func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i128
-// CHECK: "arm_sme.intr.ld1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1q.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
return
}
@@ -73,9 +75,10 @@ func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vec
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f16
-// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
return
}
@@ -83,9 +86,10 @@ func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_bf16
-// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
return
}
@@ -93,9 +97,10 @@ func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vec
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f32
-// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
return
}
@@ -103,9 +108,10 @@ func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f64
-// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
return
}
@@ -113,9 +119,10 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i8
-// CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
return
}
@@ -123,9 +130,10 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i16
-// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
return
}
@@ -133,9 +141,10 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i32
-// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
return
}
@@ -143,9 +152,10 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i64
-// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
return
}
@@ -153,9 +163,10 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i128
-// CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
return
}
@@ -163,9 +174,10 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vec
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f16
-// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
return
}
@@ -173,9 +185,10 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_bf16
-// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
return
}
@@ -183,9 +196,10 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vec
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f32
-// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
return
}
@@ -193,9 +207,10 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f64
-// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
return
}
@@ -207,25 +222,23 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vecto
// -----
// CHECK-LABEL: func.func @arm_sme_store_tile_slice_hor_i8(
-// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>,
// CHECK-SAME: %[[TILE_SLICE_INDEX:.*]]: index,
// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>,
-// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi8>) {
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi8>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[DEST]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
// CHECK: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[STRIDE:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[OFFSET:.*]] = llvm.mul %[[C0_I64]], %[[STRIDE]] : i64
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// CHECK: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
-// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK: "arm_sme.intr.st1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK: "arm_sme.intr.st1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_SLICE_INDEX_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
// CHECK: return
// CHECK: }
-func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i8(%tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
return
}
@@ -233,9 +246,10 @@ func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i16
-// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
return
}
@@ -243,9 +257,10 @@ func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i32
-// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
+// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
return
}
@@ -253,9 +268,10 @@ func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i64
-// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
+// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
return
}
@@ -263,9 +279,10 @@ func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i128
-// CHECK: "arm_sme.intr.st1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
+// CHECK: "arm_sme.intr.st1q.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i128(%tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
return
}
@@ -273,9 +290,10 @@ func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f16
-// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
return
}
@@ -283,9 +301,10 @@ func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_bf16
-// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
return
}
@@ -293,9 +312,10 @@ func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f32
-// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
+// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
return
}
@@ -303,9 +323,10 @@ func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f64
-// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
+// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
return
}
@@ -313,9 +334,10 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i8
-// CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
+// CHECK: "arm_sme.intr.st1b.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i8(%tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
return
}
@@ -323,9 +345,10 @@ func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i16
-// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
return
}
@@ -333,9 +356,10 @@ func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i32
-// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
+// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
return
}
@@ -343,9 +367,10 @@ func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i64
-// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
+// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
return
}
@@ -353,9 +378,10 @@ func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i128
-// CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
+// CHECK: "arm_sme.intr.st1q.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i128(%tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
return
}
@@ -363,9 +389,10 @@ func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f16
-// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
return
}
@@ -373,9 +400,10 @@ func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_bf16
-// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
return
}
@@ -383,9 +411,10 @@ func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f32
-// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
+// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
return
}
@@ -393,9 +422,10 @@ func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f64
-// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
+// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
return
}
@@ -407,9 +437,10 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32
-// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
-func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () {
+// CHECK: "arm_sme.intr.write.horiz"({{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile_slice_index : index) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
return
}
@@ -417,9 +448,10 @@ func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>,
// -----
// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16
-// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
-func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () {
+// CHECK: "arm_sme.intr.write.vert"({{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile_slice_index : index) -> () {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
return
}
@@ -431,8 +463,9 @@ func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>
// -----
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i8
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
-func.func @arm_sme_move_tile_slice_to_vector_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]xi8> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
+func.func @arm_sme_move_tile_slice_to_vector_i8(%tile_slice_index : index) -> vector<[16]xi8> {
+ %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
return %slice : vector<[16]xi8>
}
@@ -440,8 +473,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i8(%tile : vector<[16]x[16]xi8>, %t
// -----
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i16
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
-func.func @arm_sme_move_tile_slice_to_vector_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index) -> vector<[8]xi16> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
+func.func @arm_sme_move_tile_slice_to_vector_i16(%tile_slice_index : index) -> vector<[8]xi16> {
+ %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xi16> from vector<[8]x[8]xi16>
return %slice : vector<[8]xi16>
}
@@ -449,8 +483,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i16(%tile : vector<[8]x[8]xi16>, %t
// -----
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i32
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
-func.func @arm_sme_move_tile_slice_to_vector_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> vector<[4]xi32> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
+func.func @arm_sme_move_tile_slice_to_vector_i32(%tile_slice_index : index) -> vector<[4]xi32> {
+ %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xi32> from vector<[4]x[4]xi32>
return %slice : vector<[4]xi32>
}
@@ -458,8 +493,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i32(%tile : vector<[4]x[4]xi32>, %t
// -----
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i64
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
-func.func @arm_sme_move_tile_slice_to_vector_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index) -> vector<[2]xi64> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
+func.func @arm_sme_move_tile_slice_to_vector_i64(%tile_slice_index : index) -> vector<[2]xi64> {
+ %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xi64> from vector<[2]x[2]xi64>
return %slice : vector<[2]xi64>
}
@@ -467,8 +503,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i64(%tile : vector<[2]x[2]xi64>, %t
// -----
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i128
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
-func.func @arm_sme_move_tile_slice_to_vector_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_i128(%tile_slice_index : index) -> vector<[1]xi128> {
+ %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[1]xi128> from vector<[1]x[1]xi128>
return %slice : vector<[1]xi128>
}
@@ -476,8 +513,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i128(%tile : vector<[1]x[1]xi128>,
// -----
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f16
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
-func.func @arm_sme_move_tile_slice_to_vector_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index) -> vector<[8]xf16> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
+func.func @arm_sme_move_tile_slice_to_vector_f16(%tile_slice_index : index) -> vector<[8]xf16> {
+ %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xf16> from vector<[8]x[8]xf16>
return %slice : vector<[8]xf16>
}
@@ -485,8 +523,9 @@ func.func @arm_sme_move_tile_slice_to_vector_f16(%tile : vector<[8]x[8]xf16>, %t
// -----
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_bf16
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
-func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> vector<[8]xbf16> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
+func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile_slice_index : index) -> vector<[8]xbf16> {
+ %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
return %slice : vector<[8]xbf16>
}
@@ -494,8 +533,9 @@ func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile : vector<[8]x[8]xbf16>,
// -----
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f32
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
-func.func @arm_sme_move_tile_slice_to_vector_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf32> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
+func.func @arm_sme_move_tile_slice_to_vector_f32(%tile_slice_index : index) -> vector<[4]xf32> {
+ %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
return %slice : vector<[4]xf32>
}
@@ -503,8 +543,9 @@ func.func @arm_sme_move_tile_slice_to_vector_f32(%tile : vector<[4]x[4]xf32>, %t
// -----
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f64
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
-func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
+func.func @arm_sme_move_tile_slice_to_vector_f64(%tile_slice_index : index) -> vector<[2]xf64> {
+ %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
return %slice : vector<[2]xf64>
}
@@ -512,8 +553,9 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
// -----
// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128
-// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
-func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+// CHECK: "arm_sme.intr.read.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile_slice_index : index) -> vector<[1]xi128> {
+ %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
return %slice : vector<[1]xi128>
}
diff --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
index 06bbd3050fdecec..b7ba3f728c705a3 100644
--- a/mlir/test/Dialect/ArmSME/canonicalize.mlir
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -1,27 +1,25 @@
// RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
-// -----
-
-// CHECK-LABEL: @cast_vector_to_tile__cast_tile_to_vector
-// CHECK-SAME: %[[TILE_ID:.*]]: i8
-func.func @cast_vector_to_tile__cast_tile_to_vector(%tile_id_0 : i8) -> i8 {
- // CHECK-NOT: arm_sme.cast_tile_to_vector
- // CHECK-NOT: arm_sme.cast_vector_to_tile
- // CHECK-NEXT: return %[[TILE_ID]] : i8
- %tile = arm_sme.cast_tile_to_vector %tile_id_0 : i8 to vector<[16]x[16]xi8>
- %tile_id_1 = arm_sme.cast_vector_to_tile %tile : vector<[16]x[16]xi8> to i8
- return %tile_id_1 : i8
-}
+// This tests that the `arm_sme.materialize_ssa_tile` placeholder is removed
+// once it becomes unused, after lowering to control flow.
// -----
-// CHECK-LABEL: @cast_tile_to_vector__cast_vector_to_tile
-// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>
-func.func @cast_tile_to_vector__cast_vector_to_tile(%tile_0 : vector<[16]x[16]xi8>) -> vector<[16]x[16]xi8> {
- // CHECK-NOT: arm_sme.cast_vector_to_tile
- // CHECK-NOT: arm_sme.cast_tile_to_vector
- // CHECK-NEXT: return %[[TILE]] : vector<[16]x[16]xi8>
- %tile_id = arm_sme.cast_vector_to_tile %tile_0 : vector<[16]x[16]xi8> to i8
- %tile_1 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
- return %tile_1 : vector<[16]x[16]xi8>
+// CHECK-LABEL: @unused_materialize_ssa_tile_is_removed_from_blocks
+// CHECK-NOT: arm_sme.materialize_ssa_tile
+// CHECK-NOT: vector<[4]x[4]xf32>
+func.func @unused_materialize_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {
+ %c10 = arith.constant 10 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %tile = arm_sme.materialize_ssa_tile : vector<[4]x[4]xf32>
+ cf.br ^bb1(%c0, %tile : index, vector<[4]x[4]xf32>)
+^bb1(%1: index, %2: vector<[4]x[4]xf32>): // 2 preds: ^bb0, ^bb2
+ %3 = arith.cmpi slt, %1, %c10 : index
+ cf.cond_br %3, ^bb2, ^bb3
+^bb2: // pred: ^bb1
+ %4 = arith.addi %1, %c1 : index
+ cf.br ^bb1(%4, %tile : index, vector<[4]x[4]xf32>)
+^bb3: // pred: ^bb1
+ return
}
diff --git a/mlir/test/Dialect/ArmSME/cse.mlir b/mlir/test/Dialect/ArmSME/cse.mlir
index 734bd9f15c8dea1..74e7293eaeca5fc 100644
--- a/mlir/test/Dialect/ArmSME/cse.mlir
+++ b/mlir/test/Dialect/ArmSME/cse.mlir
@@ -1,16 +1,30 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s
-// This test is checking that CSE does not remove 'arm_sme.get_tile_id' ops as
+// This test is checking that CSE does not remove 'arm_sme.zero/get_tile' ops as
// duplicates.
-// CHECK-LABEL: @get_tile_id
-// CHECK: %[[TILE_ID_0:.*]] = arm_sme.get_tile_id : i32
-// CHECK: %[[TILE_ID_1:.*]] = arm_sme.get_tile_id : i32
-// CHECK: "prevent.dce"(%[[TILE_ID_0]]) : (i32) -> ()
-// CHECK: "prevent.dce"(%[[TILE_ID_1]]) : (i32) -> ()
-func.func @get_tile_id() {
- %tile_id_1 = arm_sme.get_tile_id : i32
- %tile_id_2 = arm_sme.get_tile_id : i32
- "prevent.dce"(%tile_id_1) : (i32) -> ()
- "prevent.dce"(%tile_id_2) : (i32) -> ()
+
+// CHECK-LABEL: @zero_tile
+// CHECK: %[[TILE_0:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK: %[[TILE_1:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
+// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
+func.func @zero_tile() {
+ %tile_1 = arm_sme.zero : vector<[4]x[4]xi32>
+ %tile_2 = arm_sme.zero : vector<[4]x[4]xi32>
+ "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
+ "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
+ return
+}
+
+// CHECK-LABEL: @get_tile
+// CHECK: %[[TILE_0:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
+// CHECK: %[[TILE_1:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
+// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
+// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
+func.func @get_tile() {
+ %tile_1 = arm_sme.get_tile : vector<[4]x[4]xi32>
+ %tile_2 = arm_sme.get_tile : vector<[4]x[4]xi32>
+ "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
+ "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
return
}
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index d35eb7f7bccc75e..85b95a8b6cf12b7 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -1,89 +1,49 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
//===----------------------------------------------------------------------===//
-// arm_sme.cast_tile_to_vector
+// arm_sme.get_tile
//===----------------------------------------------------------------------===//
// -----
-func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
-  // expected-error@+1 {{op failed to verify that `tile_id` has the same number of bits as elements in `vector`}}
- %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[8]x[8]xi16>
- return %0 : vector<[8]x[8]xi16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) -> vector<[16]xi8> {
+func.func @arm_sme_get_tile__bad_vector_type_rank_1() -> vector<[16]xi8> {
  // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
- %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8>
+ %0 = arm_sme.get_tile : vector<[16]xi8>
return %0 : vector<[16]xi8>
}
// -----
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vector<[16]x[16]xi4> {
+func.func @arm_sme_get_tile__bad_vector_type_i4() -> vector<[16]x[16]xi4> {
  // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x[16]xi4>'}}
- %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
+ %0 = arm_sme.get_tile : vector<[16]x[16]xi4>
return %0 : vector<[16]x[16]xi4>
}
// -----
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile_id : i8) -> vector<16x[16]xi8> {
+func.func @arm_sme_get_tile__bad_vector_type_non_scalable_dim_0() -> vector<16x[16]xi8> {
  // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<16x[16]xi8>'}}
- %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
+ %0 = arm_sme.get_tile : vector<16x[16]xi8>
return %0 : vector<16x[16]xi8>
}
// -----
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile_id : i8) -> vector<[16]x16xi8> {
+func.func @arm_sme_get_tile__bad_vector_type_non_scalable_dim_1() -> vector<[16]x16xi8> {
  // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x16xi8>'}}
- %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x16xi8>
+ %0 = arm_sme.get_tile : vector<[16]x16xi8>
return %0 : vector<[16]x16xi8>
}
// -----
-func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
+func.func @arm_sme_get_tile__bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
  // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[4]x[16]xi8>'}}
- %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
+ %0 = arm_sme.get_tile : vector<[4]x[16]xi8>
return %0 : vector<[4]x[16]xi8>
}
-//===----------------------------------------------------------------------===//
-// arm_sme.cast_vector_to_tile
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
-  // expected-error@+1 {{op failed to verify that `tile_id` has the same number of bits as elements in `vector`}}
- %0 = arm_sme.cast_vector_to_tile %vector : vector<[1]x[1]xi128> to i32
- return %0 : i32
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 {
-  // expected-error@+1 {{op operand #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
- %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8
- return %0 : i8
-}
-
-//===----------------------------------------------------------------------===//
-// arm_sme.get_tile_id
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_get_tile_id__bad_type() -> i1 {
-  // expected-error@+1 {{op result #0 must be an identifier of a virtual tile (of a size) within the ZA storage}}
- %0 = arm_sme.get_tile_id : i1
- return %0 : i1
-}
-
//===----------------------------------------------------------------------===//
// arm_sme.move_vector_to_tile_slice
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 1dbcc32e9259fc5..58ff7ef4d8340ec 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1,197 +1,78 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
//===----------------------------------------------------------------------===//
-// arm_sme.cast_tile_to_vector
+// arm_sme.get_tile
//===----------------------------------------------------------------------===//
-func.func @arm_sme_cast_tile_to_vector_i8(%tile_id : i8) -> vector<[16]x[16]xi8> {
- // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i8 to vector<[16]x[16]xi8>
- %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
- return %0 : vector<[16]x[16]xi8>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i16(%tile_id : i16) -> vector<[8]x[8]xi16> {
- // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xi16>
- %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xi16>
- return %0 : vector<[8]x[8]xi16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i32(%tile_id : i32) -> vector<[4]x[4]xi32> {
- // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i32 to vector<[4]x[4]xi32>
- %0 = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
- return %0 : vector<[4]x[4]xi32>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i64(%tile_id : i64) -> vector<[2]x[2]xi64> {
- // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i64 to vector<[2]x[2]xi64>
- %0 = arm_sme.cast_tile_to_vector %tile_id : i64 to vector<[2]x[2]xi64>
- return %0 : vector<[2]x[2]xi64>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i128(%tile_id : i128) -> vector<[1]x[1]xi128> {
- // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i128 to vector<[1]x[1]xi128>
- %0 = arm_sme.cast_tile_to_vector %tile_id : i128 to vector<[1]x[1]xi128>
- return %0 : vector<[1]x[1]xi128>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_f16(%tile_id : i16) -> vector<[8]x[8]xf16> {
- // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xf16>
- %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xf16>
- return %0 : vector<[8]x[8]xf16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_bf16(%tile_id : i16) -> vector<[8]x[8]xbf16> {
- // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xbf16>
- %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xbf16>
- return %0 : vector<[8]x[8]xbf16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_f32(%tile_id : i32) -> vector<[4]x[4]xf32> {
- // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i32 to vector<[4]x[4]xf32>
- %0 = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xf32>
- return %0 : vector<[4]x[4]xf32>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_f64(%tile_id : i64) -> vector<[2]x[2]xf64> {
- // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i64 to vector<[2]x[2]xf64>
- %0 = arm_sme.cast_tile_to_vector %tile_id : i64 to vector<[2]x[2]xf64>
- return %0 : vector<[2]x[2]xf64>
-}
-
-//===----------------------------------------------------------------------===//
-// arm_sme.cast_vector_to_tile
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i8(%vector : vector<[16]x[16]xi8>) -> i8 {
- // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[16]x[16]xi8> to i8
- %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]x[16]xi8> to i8
- return %0 : i8
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i16(%vector : vector<[8]x[8]xi16>) -> i16 {
- // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xi16> to i16
- %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xi16> to i16
- return %0 : i16
-}
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i32(%vector : vector<[4]x[4]xi32>) -> i32 {
- // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[4]x[4]xi32> to i32
- %0 = arm_sme.cast_vector_to_tile %vector : vector<[4]x[4]xi32> to i32
- return %0 : i32
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i64(%vector : vector<[2]x[2]xi64>) -> i64 {
- // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[2]x[2]xi64> to i64
- %0 = arm_sme.cast_vector_to_tile %vector : vector<[2]x[2]xi64> to i64
- return %0 : i64
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i128(%vector : vector<[1]x[1]xi128>) -> i128 {
- // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[1]x[1]xi128> to i128
- %0 = arm_sme.cast_vector_to_tile %vector : vector<[1]x[1]xi128> to i128
- return %0 : i128
+func.func @arm_sme_get_tile_i8() {
+ // CHECK: arm_sme.get_tile : vector<[16]x[16]xi8>
+ %0 = arm_sme.get_tile : vector<[16]x[16]xi8>
+ return
}
// -----
-func.func @arm_sme_cast_vector_to_tile_f16(%vector : vector<[8]x[8]xf16>) -> i16 {
- // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xf16> to i16
- %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xf16> to i16
- return %0 : i16
+func.func @arm_sme_get_tile_i16() {
+ // CHECK: arm_sme.get_tile : vector<[8]x[8]xi16>
+ %0 = arm_sme.get_tile : vector<[8]x[8]xi16>
+ return
}
// -----
-func.func @arm_sme_cast_vector_to_tile_bf16(%vector : vector<[8]x[8]xbf16>) -> i16 {
- // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xbf16> to i16
- %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xbf16> to i16
- return %0 : i16
+func.func @arm_sme_get_tile_i32() {
+ // CHECK: arm_sme.get_tile : vector<[4]x[4]xi32>
+ %0 = arm_sme.get_tile : vector<[4]x[4]xi32>
+ return
}
// -----
-func.func @arm_sme_cast_vector_to_tile_f32(%vector : vector<[4]x[4]xf32>) -> i32 {
- // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[4]x[4]xf32> to i32
- %0 = arm_sme.cast_vector_to_tile %vector : vector<[4]x[4]xf32> to i32
- return %0 : i32
+func.func @arm_sme_get_tile_i64() {
+ // CHECK: arm_sme.get_tile : vector<[2]x[2]xi64>
+ %0 = arm_sme.get_tile : vector<[2]x[2]xi64>
+ return
}
// -----
-func.func @arm_sme_cast_vector_to_tile_f64(%vector : vector<[2]x[2]xf64>) -> i64 {
- // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[2]x[2]xf64> to i64
- %0 = arm_sme.cast_vector_to_tile %vector : vector<[2]x[2]xf64> to i64
- return %0 : i64
-}
-
-//===----------------------------------------------------------------------===//
-// arm_sme.get_tile_id
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_get_tile_id_i8() -> i8 {
- // CHECK: arm_sme.get_tile_id : i8
- %0 = arm_sme.get_tile_id : i8
- return %0 : i8
+func.func @arm_sme_get_tile_i128() {
+ // CHECK: arm_sme.get_tile : vector<[1]x[1]xi128>
+ %0 = arm_sme.get_tile : vector<[1]x[1]xi128>
+ return
}
// -----
-func.func @arm_sme_get_tile_id_i16() -> i16 {
- // CHECK: arm_sme.get_tile_id : i16
- %0 = arm_sme.get_tile_id : i16
- return %0 : i16
+func.func @arm_sme_get_tile_f16() {
+ // CHECK: arm_sme.get_tile : vector<[8]x[8]xf16>
+ %0 = arm_sme.get_tile : vector<[8]x[8]xf16>
+ return
}
// -----
-func.func @arm_sme_get_tile_id_i32() -> i32 {
- // CHECK: arm_sme.get_tile_id : i32
- %0 = arm_sme.get_tile_id : i32
- return %0 : i32
+func.func @arm_sme_get_tile_bf16() {
+ // CHECK: arm_sme.get_tile : vector<[8]x[8]xbf16>
+ %0 = arm_sme.get_tile : vector<[8]x[8]xbf16>
+ return
}
// -----
-func.func @arm_sme_get_tile_id_i64() -> i64 {
- // CHECK: arm_sme.get_tile_id : i64
- %0 = arm_sme.get_tile_id : i64
- return %0 : i64
+func.func @arm_sme_get_tile_f32() {
+ // CHECK: arm_sme.get_tile : vector<[4]x[4]xf32>
+ %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
+ return
}
// -----
-func.func @arm_sme_get_tile_id_i128() -> i128 {
- // CHECK: arm_sme.get_tile_id : i128
- %0 = arm_sme.get_tile_id : i128
- return %0 : i128
+func.func @arm_sme_get_tile_f64() {
+ // CHECK: arm_sme.get_tile : vector<[2]x[2]xf64>
+ %0 = arm_sme.get_tile : vector<[2]x[2]xf64>
+ return
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
index a481516d4c15f93..1f895e4984ba844 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
@@ -6,17 +6,17 @@
// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65534 : i32}
func.func @mixed_tiles() {
// ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
- // CHECK-NEXT: arith.constant 0
- %za0_h = arm_sme.get_tile_id : i16
+ // CHECK-NEXT: tile_id = 0
+ %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
// ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
- // CHECK-NEXT: arith.constant 1
- %za1_s = arm_sme.get_tile_id : i32
+ // CHECK-NEXT: tile_id = 1
+ %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
// ZA3.D ZA3.Q, ZA11.Q
- // CHECK-NEXT: arith.constant 3
- %za3_d = arm_sme.get_tile_id : i64
+ // CHECK-NEXT: tile_id = 3
+ %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
// ZA7.Q
- // CHECK-NEXT: arith.constant 7
- %za7_q = arm_sme.get_tile_id : i128
+ // CHECK-NEXT: tile_id = 7
+ %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
// ZA15.Q is still free.
return
}
@@ -26,28 +26,26 @@ func.func @mixed_tiles() {
// CHECK-LABEL: za_b
// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_b() {
- // CHECK-NEXT: arith.constant 0
- %za0_b = arm_sme.get_tile_id : i8
+ // CHECK-NEXT: tile_id = 0
+ %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
return
}
// -----
func.func @za_b__out_of_tiles() {
- %za0_b = arm_sme.get_tile_id : i8
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+ %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
  // expected-error@+1 {{ran out of SME virtual tiles!}}
- %next_tile = arm_sme.get_tile_id : i8
+ %next_tile = arm_sme.get_tile : vector<[16]x[16]xi8>
return
}
// -----
func.func @za_b_overlapping_za_q() {
- %za0_b = arm_sme.get_tile_id : i8
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+ %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
  // expected-error@+1 {{ran out of SME virtual tiles!}}
- %next_tile = arm_sme.get_tile_id : i128
+ %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
@@ -56,8 +54,8 @@ func.func @za_b_overlapping_za_q() {
// CHECK-LABEL: za0_h
// CHECK-SAME: attributes {arm_sme.tiles_in_use = 43690 : i32}
func.func @za0_h() {
- // CHECK-NEXT: arith.constant 0
- %za0_h = arm_sme.get_tile_id : i16
+ // CHECK-NEXT: tile_id = 0
+ %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
return
}
@@ -66,21 +64,23 @@ func.func @za0_h() {
// CHECK-LABEL: za_h
// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_h() {
- // CHECK-NEXT: arith.constant 0
- %za0_h = arm_sme.get_tile_id : i16
- // CHECK-NEXT: arith.constant 1
- %za1_h = arm_sme.get_tile_id : i16
+ // CHECK-NEXT: tile_id = 0
+ %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+ // CHECK-NEXT: tile_id = 1
+ %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
return
}
// -----
+// CHECK-LABEL: za_h__out_of_tiles
func.func @za_h__out_of_tiles() {
- %za0_h = arm_sme.get_tile_id : i16
- %za1_h = arm_sme.get_tile_id : i16
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+ // CHECK-NEXT: tile_id = 0
+ %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+ // CHECK-NEXT: tile_id = 1
+ %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
  // expected-error@+1 {{ran out of SME virtual tiles!}}
- %next_tile = arm_sme.get_tile_id : i16
+ %next_tile = arm_sme.get_tile : vector<[8]x[8]xi16>
return
}
@@ -90,14 +90,14 @@ func.func @za_h__out_of_tiles() {
// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_h_overlapping_za_s() {
// ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
- // CHECK-NEXT: arith.constant 0
- %za0_h = arm_sme.get_tile_id : i16
+ // CHECK-NEXT: tile_id = 0
+ %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
// ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
- // CHECK-NEXT: arith.constant 1
- %za1_s = arm_sme.get_tile_id : i32
+ // CHECK-NEXT: tile_id = 1
+ %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
// ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
- // CHECK-NEXT: arith.constant 3
- %za3_s = arm_sme.get_tile_id : i32
+ // CHECK-NEXT: tile_id = 3
+ %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
return
}
@@ -107,38 +107,37 @@ func.func @za_h_overlapping_za_s() {
// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_h_overlapping_za_d() {
// ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
- // CHECK-NEXT: arith.constant 0
- %za0_h = arm_sme.get_tile_id : i16
+ // CHECK-NEXT: tile_id = 0
+ %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
// ZA1.Q, ZA9.Q
- // CHECK-NEXT: arith.constant 1
- %za1_d = arm_sme.get_tile_id : i64
+ // CHECK-NEXT: tile_id = 1
+ %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
// ZA3.Q, ZA11.Q
- // CHECK-NEXT: arith.constant 3
- %za3_d = arm_sme.get_tile_id : i64
+ // CHECK-NEXT: tile_id = 3
+ %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
// ZA5.Q, ZA13.Q
- // CHECK-NEXT: arith.constant 5
- %za5_d = arm_sme.get_tile_id : i64
+ // CHECK-NEXT: tile_id = 5
+ %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
// ZA7.Q, ZA15.Q
- // CHECK-NEXT: arith.constant 7
- %za7_d = arm_sme.get_tile_id : i64
+ // CHECK-NEXT: tile_id = 7
+ %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
return
}
// -----
func.func @za_h_overlapping_za_q() {
- %za0_h = arm_sme.get_tile_id : i16
- %za0_q = arm_sme.get_tile_id : i128
- %za2_q = arm_sme.get_tile_id : i128
- %za4_q = arm_sme.get_tile_id : i128
- %za6_q = arm_sme.get_tile_id : i128
- %za8_q = arm_sme.get_tile_id : i128
- %za10_q = arm_sme.get_tile_id : i128
- %za12_q = arm_sme.get_tile_id : i128
- %za14_q = arm_sme.get_tile_id : i128
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+ %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+ %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
  // expected-error@+1 {{ran out of SME virtual tiles!}}
- %next_tile = arm_sme.get_tile_id : i128
+ %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
@@ -147,8 +146,8 @@ func.func @za_h_overlapping_za_q() {
// CHECK-LABEL: za0_s
// CHECK-SAME: attributes {arm_sme.tiles_in_use = 34952 : i32}
func.func @za0_s() {
- // CHECK-NEXT: arith.constant 0
- %za0_s = arm_sme.get_tile_id : i32
+ // CHECK-NEXT: tile_id = 0
+ %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
return
}
@@ -157,27 +156,26 @@ func.func @za0_s() {
// CHECK-LABEL: za_s
// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_s() {
- // CHECK-NEXT: arith.constant 0
- %za0_s = arm_sme.get_tile_id : i32
- // CHECK-NEXT: arith.constant 1
- %za1_s = arm_sme.get_tile_id : i32
- // CHECK-NEXT: arith.constant 2
- %za2_s = arm_sme.get_tile_id : i32
- // CHECK-NEXT: arith.constant 3
- %za3_s = arm_sme.get_tile_id : i32
+ // CHECK-NEXT: tile_id = 0
+ %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+ // CHECK-NEXT: tile_id = 1
+ %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+ // CHECK-NEXT: tile_id = 2
+ %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+ // CHECK-NEXT: tile_id = 3
+ %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
return
}
// -----
func.func @za_s__out_of_tiles() {
- %za0_s = arm_sme.get_tile_id : i32
- %za1_s = arm_sme.get_tile_id : i32
- %za2_s = arm_sme.get_tile_id : i32
- %za3_s = arm_sme.get_tile_id : i32
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+ %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+ %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+ %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+ %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
  // expected-error@+1 {{ran out of SME virtual tiles!}}
- %next_tile = arm_sme.get_tile_id : i32
+ %next_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
return
}
@@ -187,42 +185,41 @@ func.func @za_s__out_of_tiles() {
// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_s_overlapping_za_d() {
// ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
- // CHECK-NEXT: arith.constant 0
- %za0_s = arm_sme.get_tile_id : i32
+ // CHECK-NEXT: tile_id = 0
+ %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
// ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
- // CHECK-NEXT: arith.constant 1
- %za1_s = arm_sme.get_tile_id : i32
+ // CHECK-NEXT: tile_id = 1
+ %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
// ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
- // CHECK-NEXT: arith.constant 2
- %za2_s = arm_sme.get_tile_id : i32
+ // CHECK-NEXT: tile_id = 2
+ %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
// ZA3.Q, ZA11.Q
- // CHECK-NEXT: arith.constant 3
- %za3_d = arm_sme.get_tile_id : i64
+ // CHECK-NEXT: tile_id = 3
+ %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
// ZA7.Q, ZA15.Q
- // CHECK-NEXT: arith.constant 7
- %za7_d = arm_sme.get_tile_id : i64
+ // CHECK-NEXT: tile_id = 7
+ %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
return
}
// -----
func.func @za_s_overlapping_za_q() {
- %za0_s = arm_sme.get_tile_id : i32
- %za1_q = arm_sme.get_tile_id : i128
- %za2_q = arm_sme.get_tile_id : i128
- %za3_q = arm_sme.get_tile_id : i128
- %za5_q = arm_sme.get_tile_id : i128
- %za6_q = arm_sme.get_tile_id : i128
- %za7_q = arm_sme.get_tile_id : i128
- %za9_q = arm_sme.get_tile_id : i128
- %za10_q = arm_sme.get_tile_id : i128
- %za11_q = arm_sme.get_tile_id : i128
- %za13_q = arm_sme.get_tile_id : i128
- %za14_q = arm_sme.get_tile_id : i128
- %za15_q = arm_sme.get_tile_id : i128
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+ %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+ %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
  // expected-error@+1 {{ran out of SME virtual tiles!}}
- %next_tile = arm_sme.get_tile_id : i128
+ %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
@@ -231,8 +228,8 @@ func.func @za_s_overlapping_za_q() {
// CHECK-LABEL: za0_d
// CHECK-SAME: attributes {arm_sme.tiles_in_use = 32896 : i32}
func.func @za0_d() {
- // CHECK-NEXT: arith.constant 0
- %za0_d = arm_sme.get_tile_id : i64
+ // CHECK-NEXT: tile_id = 0
+ %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
return
}
@@ -241,63 +238,61 @@ func.func @za0_d() {
// CHECK-LABEL: za_d
// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_d() {
- // CHECK-NEXT: arith.constant 0
- %za0_d = arm_sme.get_tile_id : i64
- // CHECK-NEXT: arith.constant 1
- %za1_d = arm_sme.get_tile_id : i64
- // CHECK-NEXT: arith.constant 2
- %za2_d = arm_sme.get_tile_id : i64
- // CHECK-NEXT: arith.constant 3
- %za3_d = arm_sme.get_tile_id : i64
- // CHECK-NEXT: arith.constant 4
- %za4_d = arm_sme.get_tile_id : i64
- // CHECK-NEXT: arith.constant 5
- %za5_d = arm_sme.get_tile_id : i64
- // CHECK-NEXT: arith.constant 6
- %za6_d = arm_sme.get_tile_id : i64
- // CHECK-NEXT: arith.constant 7
- %za7_d = arm_sme.get_tile_id : i64
+ // CHECK-NEXT: tile_id = 0
+ %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ // CHECK-NEXT: tile_id = 1
+ %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ // CHECK-NEXT: tile_id = 2
+ %za2_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ // CHECK-NEXT: tile_id = 3
+ %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ // CHECK-NEXT: tile_id = 4
+ %za4_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ // CHECK-NEXT: tile_id = 5
+ %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ // CHECK-NEXT: tile_id = 6
+ %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ // CHECK-NEXT: tile_id = 7
+ %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
return
}
// -----
func.func @za_d__out_of_tiles() {
- %za0_d = arm_sme.get_tile_id : i64
- %za1_d = arm_sme.get_tile_id : i64
- %za2_d = arm_sme.get_tile_id : i64
- %za3_d = arm_sme.get_tile_id : i64
- %za4_d = arm_sme.get_tile_id : i64
- %za5_d = arm_sme.get_tile_id : i64
- %za6_d = arm_sme.get_tile_id : i64
- %za7_d = arm_sme.get_tile_id : i64
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+ %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ %za2_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ %za4_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
  // expected-error@+1 {{ran out of SME virtual tiles!}}
- %next_tile = arm_sme.get_tile_id : i64
+ %next_tile = arm_sme.get_tile : vector<[2]x[2]xi64>
return
}
// -----
func.func @za_d_overlapping_za_q() {
- %za0_d = arm_sme.get_tile_id : i64
- %za1_q = arm_sme.get_tile_id : i128
- %za2_q = arm_sme.get_tile_id : i128
- %za3_q = arm_sme.get_tile_id : i128
- %za4_q = arm_sme.get_tile_id : i128
- %za5_q = arm_sme.get_tile_id : i128
- %za6_q = arm_sme.get_tile_id : i128
- %za7_q = arm_sme.get_tile_id : i128
- %za9_q = arm_sme.get_tile_id : i128
- %za10_q = arm_sme.get_tile_id : i128
- %za11_q = arm_sme.get_tile_id : i128
- %za12_q = arm_sme.get_tile_id : i128
- %za13_q = arm_sme.get_tile_id : i128
- %za14_q = arm_sme.get_tile_id : i128
- %za15_q = arm_sme.get_tile_id : i128
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+ %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
  // expected-error@+1 {{ran out of SME virtual tiles!}}
- %next_tile = arm_sme.get_tile_id : i128
+ %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
@@ -306,8 +301,8 @@ func.func @za_d_overlapping_za_q() {
// CHECK-LABEL: za0_q
// CHECK-SAME: attributes {arm_sme.tiles_in_use = 32768 : i32}
func.func @za0_q() {
- // CHECK-NEXT: arith.constant 0
- %za0_q = arm_sme.get_tile_id : i128
+ // CHECK-NEXT: tile_id = 0
+ %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
@@ -316,62 +311,61 @@ func.func @za0_q() {
// CHECK-LABEL: za_q
// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_q() {
- // CHECK-NEXT: arith.constant 0
- %za0_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 1
- %za1_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 2
- %za2_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 3
- %za3_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 4
- %za4_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 5
- %za5_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 6
- %za6_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 7
- %za7_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 8
- %za8_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 9
- %za9_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 10
- %za10_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 11
- %za11_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 12
- %za12_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 13
- %za13_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 14
- %za14_q = arm_sme.get_tile_id : i128
- // CHECK-NEXT: arith.constant 15
- %za15_q = arm_sme.get_tile_id : i128
+ // CHECK-NEXT: tile_id = 0
+ %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 1
+ %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 2
+ %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 3
+ %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 4
+ %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 5
+ %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 6
+ %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 7
+ %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 8
+ %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 9
+ %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 10
+ %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 11
+ %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 12
+ %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 13
+ %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 14
+ %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 15
+ %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
// -----
func.func @za_q__out_of_tiles() {
- %za0_q = arm_sme.get_tile_id : i128
- %za1_q = arm_sme.get_tile_id : i128
- %za2_q = arm_sme.get_tile_id : i128
- %za3_q = arm_sme.get_tile_id : i128
- %za4_q = arm_sme.get_tile_id : i128
- %za5_q = arm_sme.get_tile_id : i128
- %za6_q = arm_sme.get_tile_id : i128
- %za7_q = arm_sme.get_tile_id : i128
- %za8_q = arm_sme.get_tile_id : i128
- %za9_q = arm_sme.get_tile_id : i128
- %za10_q = arm_sme.get_tile_id : i128
- %za11_q = arm_sme.get_tile_id : i128
- %za12_q = arm_sme.get_tile_id : i128
- %za13_q = arm_sme.get_tile_id : i128
- %za14_q = arm_sme.get_tile_id : i128
- %za15_q = arm_sme.get_tile_id : i128
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+ %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
  // expected-error@+1 {{ran out of SME virtual tiles!}}
- %next_tile = arm_sme.get_tile_id : i128
+ %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index 2378f4234aef1ef..04412e4db1c5f3a 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -1,7 +1,4 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-llvm \
-// RUN: -allocate-arm-sme-tiles -canonicalize \
-// RUN: -allow-unregistered-dialect \
-// RUN: | FileCheck %s
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize | FileCheck %s
// This test verifies the tile mask operand of the zero intrinsic zeroes
// the correct tiles. Both integer and floating-point datatypes are checked.
@@ -10,13 +7,8 @@
// CHECK-LABEL: zero_za_b
func.func @zero_za_b() {
- // CHECK-DAG: %[[TILE_ID:.*]] = arith.constant 0 : i8
- // CHECK-DAG: %[[ZERO_MASK:.*]] = arith.constant 255 : i32
-
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA0B:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
%zero_za0b = arm_sme.zero : vector<[16]x[16]xi8>
- "prevent.dce"(%zero_za0b) : (vector<[16]x[16]xi8>) -> ()
return
}
@@ -24,20 +16,10 @@ func.func @zero_za_b() {
// CHECK-LABEL: zero_za_h
func.func @zero_za_h() {
- // CHECK-DAG: %[[TILE_ID_ZA0H:.*]] = arith.constant 0 : i16
- // CHECK-DAG: %[[TILE_ID_ZA1H:.*]] = arith.constant 1 : i16
-
- // CHECK-DAG: %[[ZERO_MASK_ZA0H:.*]] = arith.constant 85 : i32
- // CHECK-DAG: %[[ZERO_MASK_ZA1H:.*]] = arith.constant 170 : i32
-
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0H]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA0H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0H]] : i16 to vector<[8]x[8]xi16>
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
%zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
- "prevent.dce"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1H]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xf16>
+ // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 170 : i32}> : () -> ()
%zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
- "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xf16>) -> ()
return
}
@@ -45,32 +27,14 @@ func.func @zero_za_h() {
// CHECK-LABEL: zero_za_s
func.func @zero_za_s() {
- // CHECK-DAG: %[[TILE_ID_ZA0S:.*]] = arith.constant 0 : i32
- // CHECK-DAG: %[[TILE_ID_ZA1S:.*]] = arith.constant 1 : i32
- // CHECK-DAG: %[[TILE_ID_ZA2S:.*]] = arith.constant 2 : i32
- // CHECK-DAG: %[[TILE_ID_ZA3S:.*]] = arith.constant 3 : i32
-
- // CHECK-DAG: %[[ZERO_MASK_ZA0S:.*]] = arith.constant 17 : i32
- // CHECK-DAG: %[[ZERO_MASK_ZA1S:.*]] = arith.constant 34 : i32
- // CHECK-DAG: %[[ZERO_MASK_ZA2S:.*]] = arith.constant 68 : i32
- // CHECK-DAG: %[[ZERO_MASK_ZA3S:.*]] = arith.constant 136 : i32
-
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0S]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA0S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0S]] : i32 to vector<[4]x[4]xi32>
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
%zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
- "prevent.dce"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1S]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA1S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1S]] : i32 to vector<[4]x[4]xi32>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
%zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
- "prevent.dce"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA2S]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA2S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA2S]] : i32 to vector<[4]x[4]xi32>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
%zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
- "prevent.dce"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3S]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xf32>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()
%zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
- "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
return
}
@@ -78,55 +42,21 @@ func.func @zero_za_s() {
// CHECK-LABEL: zero_za_d
func.func @zero_za_d() {
- // CHECK-DAG: %[[TILE_ID_ZA0D:.*]] = arith.constant 0 : i64
- // CHECK-DAG: %[[TILE_ID_ZA1D:.*]] = arith.constant 1 : i64
- // CHECK-DAG: %[[TILE_ID_ZA2D:.*]] = arith.constant 2 : i64
- // CHECK-DAG: %[[TILE_ID_ZA3D:.*]] = arith.constant 3 : i64
- // CHECK-DAG: %[[TILE_ID_ZA4D:.*]] = arith.constant 4 : i64
- // CHECK-DAG: %[[TILE_ID_ZA5D:.*]] = arith.constant 5 : i64
- // CHECK-DAG: %[[TILE_ID_ZA6D:.*]] = arith.constant 6 : i64
- // CHECK-DAG: %[[TILE_ID_ZA7D:.*]] = arith.constant 7 : i64
-
- // CHECK-DAG: %[[ZERO_MASK_ZA0D:.*]] = arith.constant 1 : i32
- // CHECK-DAG: %[[ZERO_MASK_ZA1D:.*]] = arith.constant 2 : i32
- // CHECK-DAG: %[[ZERO_MASK_ZA2D:.*]] = arith.constant 4 : i32
- // CHECK-DAG: %[[ZERO_MASK_ZA3D:.*]] = arith.constant 8 : i32
- // CHECK-DAG: %[[ZERO_MASK_ZA4D:.*]] = arith.constant 16 : i32
- // CHECK-DAG: %[[ZERO_MASK_ZA5D:.*]] = arith.constant 32 : i32
- // CHECK-DAG: %[[ZERO_MASK_ZA6D:.*]] = arith.constant 64 : i32
- // CHECK-DAG: %[[ZERO_MASK_ZA7D:.*]] = arith.constant 128 : i32
-
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0D]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA0D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0D]] : i64 to vector<[2]x[2]xi64>
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
%zero_za0d = arm_sme.zero : vector<[2]x[2]xi64>
- "prevent.dce"(%zero_za0d) : (vector<[2]x[2]xi64>) -> ()
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1D]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA1D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1D]] : i64 to vector<[2]x[2]xi64>
+ // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 2 : i32}> : () -> ()
%zero_za1d = arm_sme.zero : vector<[2]x[2]xi64>
- "prevent.dce"(%zero_za1d) : (vector<[2]x[2]xi64>) -> ()
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA2D]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA2D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA2D]] : i64 to vector<[2]x[2]xi64>
+ // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 4 : i32}> : () -> ()
%zero_za2d = arm_sme.zero : vector<[2]x[2]xi64>
- "prevent.dce"(%zero_za2d) : (vector<[2]x[2]xi64>) -> ()
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3D]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA3D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3D]] : i64 to vector<[2]x[2]xi64>
+ // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 8 : i32}> : () -> ()
%zero_za3d = arm_sme.zero : vector<[2]x[2]xi64>
- "prevent.dce"(%zero_za3d) : (vector<[2]x[2]xi64>) -> ()
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA4D]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA4D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA4D]] : i64 to vector<[2]x[2]xi64>
+ // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 16 : i32}> : () -> ()
%zero_za4d = arm_sme.zero : vector<[2]x[2]xi64>
- "prevent.dce"(%zero_za4d) : (vector<[2]x[2]xi64>) -> ()
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA5D]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA5D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA5D]] : i64 to vector<[2]x[2]xi64>
+ // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 32 : i32}> : () -> ()
%zero_za5d = arm_sme.zero : vector<[2]x[2]xi64>
- "prevent.dce"(%zero_za5d) : (vector<[2]x[2]xi64>) -> ()
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA6D]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA6D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA6D]] : i64 to vector<[2]x[2]xi64>
+ // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 64 : i32}> : () -> ()
%zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
- "prevent.dce"(%zero_za6d) : (vector<[2]x[2]xi64>) -> ()
- // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA7D]]) : (i32) -> ()
- // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xf64>
+ // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 128 : i32}> : () -> ()
%zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
- "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
return
}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 77ac071ef67de9a..d16d250b70eb3f6 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
//===----------------------------------------------------------------------===//
// vector.transfer_write
@@ -10,13 +10,9 @@
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
-// CHECK-DAG: %[[C255:.*]] = arith.constant 255 : i32
// CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
// CHECK-DAG: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK-DAG: %[[EXT_TILE_ID:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-DAG: %[[TILE_MASK:.*]] = arith.shli %[[C255]], %[[EXT_TILE_ID]] : i32
-// CHECK-DAG: "arm_sme.intr.zero"(%[[TILE_MASK]]) : (i32) -> ()
+// CHECK-DAG: "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
@@ -27,8 +23,7 @@
// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]] : i64
// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-NEXT: "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0> : vector<[16]x[16]xi8>
@@ -49,8 +44,6 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
// CHECK-LABEL: @vector_load_i8_with_offset(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
// CHECK-DAG: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C123:.*]] = arith.constant 123 : index
@@ -68,10 +61,8 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]] : i64
// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
// CHECK-NEXT: }
-// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
%c0 = arith.constant 0 : index
%c123 = arith.constant 123 : index
@@ -84,8 +75,6 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16
// CHECK-LABEL: @vector_load_i8_from_rank_1_memref(
// CHECK-SAME: %[[ARG0:.*]]: memref<?xi8>)
// CHECK-DAG: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xi8> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
@@ -98,10 +87,8 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16
// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX_I64]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
// CHECK-NEXT: }
-// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
@@ -113,12 +100,10 @@ func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16
// CHECK-LABEL: @vector_load_i16(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi16>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xi16>
// CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
// CHECK: %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
// CHECK: arm_sme.intr.ld1h.horiz
+// CHECK-SAME: tile_id = 0
func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
@@ -129,13 +114,10 @@ func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
// CHECK-LABEL: @vector_load_i32(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi32>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
// CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
// CHECK: %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK-NOT: arith.extui %[[TILE_ID]]
-// CHECK-NOT: arith.trunci %[[TILE_ID]]
// CHECK: arm_sme.intr.ld1w.horiz
+// CHECK-SAME: tile_id = 0
func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -146,12 +128,10 @@ func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
// CHECK-LABEL: @vector_load_i64(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi64>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xi64>
// CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
// CHECK: %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
// CHECK: arm_sme.intr.ld1d.horiz
+// CHECK-SAME: tile_id = 0
func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
@@ -162,12 +142,10 @@ func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
// CHECK-LABEL: @vector_load_f16(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf16>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
// CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
// CHECK: %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
// CHECK: arm_sme.intr.ld1h.horiz
+// CHECK-SAME: tile_id = 0
func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
@@ -178,12 +156,10 @@ func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
// CHECK-LABEL: @vector_load_bf16(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xbf16>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xbf16>
// CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
// CHECK: %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
// CHECK: arm_sme.intr.ld1h.horiz
+// CHECK-SAME: tile_id = 0
func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
@@ -194,13 +170,10 @@ func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
// CHECK-LABEL: @vector_load_f32(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xf32>
// CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
// CHECK: %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK-NOT: arith.extui %[[TILE_ID]]
-// CHECK-NOT: arith.trunci %[[TILE_ID]]
// CHECK: arm_sme.intr.ld1w.horiz
+// CHECK-SAME: tile_id = 0
func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -211,12 +184,10 @@ func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
// CHECK-LABEL: @vector_load_f64(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf64>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xf64>
// CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
// CHECK: %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
// CHECK: arm_sme.intr.ld1d.horiz
+// CHECK-SAME: tile_id = 0
func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
@@ -227,10 +198,8 @@ func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
// CHECK-LABEL: @vector_load_i128(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi128>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i128
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i128 to vector<[1]x[1]xi128>
-// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i128 to i32
// CHECK: arm_sme.intr.ld1q.horiz
+// CHECK-SAME: tile_id = 0
func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
%c0 = arith.constant 0 : index
%tile = vector.load %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
@@ -244,7 +213,6 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
// -----
// CHECK-LABEL: @vector_store_i8(
-// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
// CHECK-DAG: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
@@ -256,19 +224,18 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
// CHECK: %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64
-// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]] : i64
// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK-NEXT: "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: return
-func.func @vector_store_i8(%tile : vector<[16]x[16]xi8>, %arg0 : memref<?x?xi8>) {
+func.func @vector_store_i8(%arg0 : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
return
}
@@ -276,15 +243,14 @@ func.func @vector_store_i8(%tile : vector<[16]x[16]xi8>, %arg0 : memref<?x?xi8>)
// -----
// CHECK-LABEL: @vector_store_i16(
-// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xi16>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi16>)
// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
// CHECK: %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xi16> to i16
-// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
// CHECK: arm_sme.intr.st1h.horiz
-func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>) {
+// CHECK-SAME: tile_id = 0
+func.func @vector_store_i16(%arg0 : memref<?x?xi16>) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
return
}
@@ -292,16 +258,14 @@ func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>
// -----
// CHECK-LABEL: @vector_store_i32(
-// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi32>)
// CHECK: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
// CHECK: %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
-// CHECK-NOT: arith.extui %[[CAST_VECTOR_TO_TILE]]
-// CHECK-NOT: arith.trunci %[[CAST_VECTOR_TO_TILE]]
// CHECK: arm_sme.intr.st1w.horiz
-func.func @vector_store_i32(%tile : vector<[4]x[4]xi32>, %arg0 : memref<?x?xi32>) {
+// CHECK-SAME: tile_id = 0
+func.func @vector_store_i32(%arg0 : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
@@ -309,15 +273,14 @@ func.func @vector_store_i32(%tile : vector<[4]x[4]xi32>, %arg0 : memref<?x?xi32>
// -----
// CHECK-LABEL: @vector_store_i64(
-// CHECK-SAME: %[[TILE:.*]]: vector<[2]x[2]xi64>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi64>)
// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
// CHECK: %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xi64> to i64
-// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
// CHECK: arm_sme.intr.st1d.horiz
-func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>) {
+// CHECK-SAME: tile_id = 0
+func.func @vector_store_i64(%arg0 : memref<?x?xi64>) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
return
}
@@ -325,15 +288,14 @@ func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>
// -----
// CHECK-LABEL: @vector_store_f16(
-// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xf16>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf16>)
// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
// CHECK: %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xf16> to i16
-// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
// CHECK: arm_sme.intr.st1h.horiz
-func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>) {
+// CHECK-SAME: tile_id = 0
+func.func @vector_store_f16(%arg0 : memref<?x?xf16>) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
return
}
@@ -341,31 +303,28 @@ func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>
// -----
// CHECK-LABEL: @vector_store_bf16(
-// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xbf16>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xbf16>)
// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
// CHECK: %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xbf16> to i16
-// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
// CHECK: arm_sme.intr.st1h.horiz
-func.func @vector_store_bf16(%tile : vector<[8]x[8]xbf16>, %arg0 : memref<?x?xbf16>) {
+// CHECK-SAME: tile_id = 0
+func.func @vector_store_bf16(%arg0 : memref<?x?xbf16>) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
return
}
// -----
// CHECK-LABEL: @vector_store_f32(
-// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>)
// CHECK: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
// CHECK: %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
-// CHECK-NOT: arith.extui %[[CAST_VECTOR_TO_TILE]]
-// CHECK-NOT: arith.trunci %[[CAST_VECTOR_TO_TILE]]
// CHECK: arm_sme.intr.st1w.horiz
-func.func @vector_store_f32(%tile : vector<[4]x[4]xf32>, %arg0 : memref<?x?xf32>) {
+// CHECK-SAME: tile_id = 0
+func.func @vector_store_f32(%arg0 : memref<?x?xf32>) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
return
}
@@ -373,15 +332,14 @@ func.func @vector_store_f32(%tile : vector<[4]x[4]xf32>, %arg0 : memref<?x?xf32>
// -----
// CHECK-LABEL: @vector_store_f64(
-// CHECK-SAME: %[[TILE:.*]]: vector<[2]x[2]xf64>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf64>)
// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
// CHECK: %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xf64> to i64
-// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
// CHECK: arm_sme.intr.st1d.horiz
-func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>) {
+// CHECK-SAME: tile_id = 0
+func.func @vector_store_f64(%arg0 : memref<?x?xf64>) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}
@@ -389,13 +347,12 @@ func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>
// -----
// CHECK-LABEL: @vector_store_i128(
-// CHECK-SAME: %[[TILE:.*]]: vector<[1]x[1]xi128>,
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi128>)
-// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[1]x[1]xi128> to i128
-// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i128 to i32
// CHECK: arm_sme.intr.st1q.horiz
-func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi128>) {
+// CHECK-SAME: tile_id = 0
+func.func @vector_store_i128(%arg0 : memref<?x?xi128>) {
%c0 = arith.constant 0 : index
+ %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
return
}
@@ -407,12 +364,11 @@ func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi1
// -----
// CHECK-LABEL: @vector_outerproduct_add_f16
-// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>)
-func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>)
+func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>) {
// CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1>
- // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[8]x[8]xf16> to i16
- // CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
- // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+ // CHECK: "arm_sme.intr.mopa"(%[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+ %acc = arm_sme.get_tile : vector<[8]x[8]xf16>
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
"prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
}
@@ -420,8 +376,9 @@ func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]
// -----
// CHECK-LABEL: @vector_outerproduct_add_bf16
-func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
- // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>) {
+ // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+ %acc = arm_sme.get_tile : vector<[8]x[8]xbf16>
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
"prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
}
@@ -429,10 +386,9 @@ func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[
// -----
// CHECK-LABEL: @vector_outerproduct_add_f32
-func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
- // CHECK-NOT: arith.extui
- // CHECK-NOT: arith.trunci
- // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>) {
+ // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+ %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}
@@ -440,9 +396,9 @@ func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]
// -----
// CHECK-LABEL: @vector_outerproduct_add_f64
-func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
- // CHECK: arith.trunci {{.*}} : i64 to i32
- // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
+ // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+ %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
}
@@ -451,8 +407,8 @@ func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]
// CHECK-LABEL: @vector_outerproduct_no_accumulator
func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
- // CHECK: "arm_sme.intr.zero"({{.*}}) : (i32) -> ()
- // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
%0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
}
@@ -460,12 +416,12 @@ func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vec
// -----
// CHECK-LABEL: @vector_outerproduct_masked_f32
-// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
-func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %dim0 : index, %dim1 : index) {
// CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
// CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
- // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[4]x[4]xf32> to i32
- // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE]], %[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+ // CHECK: "arm_sme.intr.mopa"(%[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) <{tile_id = 0 : i32}> : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+ %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
%mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
%result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
"prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
@@ -474,11 +430,12 @@ func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<
// -----
// CHECK-LABEL: @vector_outerproduct_masked_f16
-// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
-func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>,
+func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %dim0 : index, %dim1 : index) {
// CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
// CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
- // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+ // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+ %acc = arm_sme.get_tile : vector<[8]x[8]xf16>
%mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
%result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
"prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
@@ -487,11 +444,12 @@ func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<
// -----
// CHECK-LABEL: @vector_outerproduct_masked_bf16
-// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>,
-func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>
+func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %dim0 : index, %dim1 : index) {
// CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
// CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
- // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+ // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+ %acc = arm_sme.get_tile : vector<[8]x[8]xbf16>
%mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
%result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
"prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
@@ -500,11 +458,12 @@ func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vecto
// -----
// CHECK-LABEL: @vector_outerproduct_masked_f64
-// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
-func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>,
+func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %dim0 : index, %dim1 : index) {
// CHECK: vector.create_mask {{.*}} : vector<[2]xi1>
// CHECK: vector.create_mask {{.*}} : vector<[2]xi1>
- // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+ // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+ %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
%mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
%result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
"prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
@@ -520,7 +479,8 @@ func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f
// -----
-func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
+func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>) {
+ %acc = arm_sme.get_tile : vector<[16]x[16]xi8>
// expected-error at +2 {{failed to legalize operation 'arm_sme.outerproduct'}}
// expected-error at +1 {{unsupported type}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
@@ -529,7 +489,8 @@ func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : v
// -----
-func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
+func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
+ %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
// expected-error at +1 {{unsupported kind}}
%0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
@@ -537,8 +498,9 @@ func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : v
// -----
-func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
+func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
// CHECK: vector.outerproduct
+ %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}
@@ -550,14 +512,13 @@ func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vecto
// -----
// CHECK-LABEL: @vector_insert_slice_i32(
-// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME: %[[SLICE:.*]]: vector<[4]xi32>,
// CHECK-SAME: %[[INDEX:.*]]: index)
-func.func @vector_insert_slice_i32(%tile: vector<[4]x[4]xi32>, %slice: vector<[4]xi32>, %row: index) -> vector<[4]x[4]xi32>{
+func.func @vector_insert_slice_i32(%slice: vector<[4]xi32>, %row: index) -> vector<[4]x[4]xi32>{
// CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
- // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
- // CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_castui %[[INDEX]] : index to i32
- // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[TILE_SLICE_INDEX]], %[[PTRUE]], %[[SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+ // CHECK: %[[TILE_SLICE_INDEX:.*]] = arith.index_castui %[[INDEX]] : index to i32
+ // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_SLICE_INDEX]], %[[PTRUE]], %[[SLICE]]) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+ %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
%new_tile = vector.insert %slice, %tile[%row] : vector<[4]xi32> into vector<[4]x[4]xi32>
return %new_tile : vector<[4]x[4]xi32>
}
@@ -565,8 +526,9 @@ func.func @vector_insert_slice_i32(%tile: vector<[4]x[4]xi32>, %slice: vector<[4
// -----
// CHECK-LABEL: @vector_insert_slice_i8
-func.func @vector_insert_slice_i8(%tile: vector<[16]x[16]xi8>, %slice: vector<[16]xi8>, %row: index) -> vector<[16]x[16]xi8> {
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+func.func @vector_insert_slice_i8(%slice: vector<[16]xi8>, %row: index) -> vector<[16]x[16]xi8> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+ %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
%new_tile = vector.insert %slice, %tile[%row] : vector<[16]xi8> into vector<[16]x[16]xi8>
return %new_tile : vector<[16]x[16]xi8>
}
@@ -574,8 +536,9 @@ func.func @vector_insert_slice_i8(%tile: vector<[16]x[16]xi8>, %slice: vector<[1
// -----
// CHECK-LABEL: @vector_insert_slice_i16
-func.func @vector_insert_slice_i16(%tile: vector<[8]x[8]xi16>, %slice: vector<[8]xi16>, %row: index) -> vector<[8]x[8]xi16> {
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+func.func @vector_insert_slice_i16(%slice: vector<[8]xi16>, %row: index) -> vector<[8]x[8]xi16> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+ %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
%new_tile = vector.insert %slice, %tile[%row] : vector<[8]xi16> into vector<[8]x[8]xi16>
return %new_tile : vector<[8]x[8]xi16>
}
@@ -583,8 +546,9 @@ func.func @vector_insert_slice_i16(%tile: vector<[8]x[8]xi16>, %slice: vector<[8
// -----
// CHECK-LABEL: @vector_insert_slice_i64
-func.func @vector_insert_slice_i64(%tile: vector<[2]x[2]xi64>, %slice: vector<[2]xi64>, %row: index) -> vector<[2]x[2]xi64> {
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+func.func @vector_insert_slice_i64(%slice: vector<[2]xi64>, %row: index) -> vector<[2]x[2]xi64> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+ %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
%new_tile = vector.insert %slice, %tile[%row] : vector<[2]xi64> into vector<[2]x[2]xi64>
return %new_tile : vector<[2]x[2]xi64>
}
@@ -592,8 +556,9 @@ func.func @vector_insert_slice_i64(%tile: vector<[2]x[2]xi64>, %slice: vector<[2
// -----
// CHECK-LABEL: @vector_insert_slice_i128
-func.func @vector_insert_slice_i128(%tile: vector<[1]x[1]xi128>, %slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> {
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+func.func @vector_insert_slice_i128(%slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+ %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
%new_tile = vector.insert %slice, %tile[%row] : vector<[1]xi128> into vector<[1]x[1]xi128>
return %new_tile : vector<[1]x[1]xi128>
}
@@ -601,8 +566,9 @@ func.func @vector_insert_slice_i128(%tile: vector<[1]x[1]xi128>, %slice: vector<
// -----
// CHECK-LABEL: @vector_insert_slice_f16
-func.func @vector_insert_slice_f16(%tile: vector<[8]x[8]xf16>, %slice: vector<[8]xf16>, %row: index) -> vector<[8]x[8]xf16> {
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+func.func @vector_insert_slice_f16(%slice: vector<[8]xf16>, %row: index) -> vector<[8]x[8]xf16> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+ %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
%new_tile = vector.insert %slice, %tile[%row] : vector<[8]xf16> into vector<[8]x[8]xf16>
return %new_tile : vector<[8]x[8]xf16>
}
@@ -610,8 +576,9 @@ func.func @vector_insert_slice_f16(%tile: vector<[8]x[8]xf16>, %slice: vector<[8
// -----
// CHECK-LABEL: @vector_insert_slice_bf16
-func.func @vector_insert_slice_bf16(%tile: vector<[8]x[8]xbf16>, %slice: vector<[8]xbf16>, %row: index) -> vector<[8]x[8]xbf16> {
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @vector_insert_slice_bf16(%slice: vector<[8]xbf16>, %row: index) -> vector<[8]x[8]xbf16> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+ %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
%new_tile = vector.insert %slice, %tile[%row] : vector<[8]xbf16> into vector<[8]x[8]xbf16>
return %new_tile : vector<[8]x[8]xbf16>
}
@@ -619,8 +586,9 @@ func.func @vector_insert_slice_bf16(%tile: vector<[8]x[8]xbf16>, %slice: vector<
// -----
// CHECK-LABEL: @vector_insert_slice_f32
-func.func @vector_insert_slice_f32(%tile: vector<[4]x[4]xf32>, %slice: vector<[4]xf32>, %row: index) -> vector<[4]x[4]xf32> {
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+func.func @vector_insert_slice_f32(%slice: vector<[4]xf32>, %row: index) -> vector<[4]x[4]xf32> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+ %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%new_tile = vector.insert %slice, %tile[%row] : vector<[4]xf32> into vector<[4]x[4]xf32>
return %new_tile : vector<[4]x[4]xf32>
}
@@ -628,8 +596,9 @@ func.func @vector_insert_slice_f32(%tile: vector<[4]x[4]xf32>, %slice: vector<[4
// -----
// CHECK-LABEL: @vector_insert_slice_f64
-func.func @vector_insert_slice_f64(%tile: vector<[2]x[2]xf64>, %slice: vector<[2]xf64>, %row: index) -> vector<[2]x[2]xf64> {
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+func.func @vector_insert_slice_f64(%slice: vector<[2]xf64>, %row: index) -> vector<[2]x[2]xf64> {
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+ %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
%new_tile = vector.insert %slice, %tile[%row] : vector<[2]xf64> into vector<[2]x[2]xf64>
return %new_tile : vector<[2]x[2]xf64>
}
@@ -637,19 +606,18 @@ func.func @vector_insert_slice_f64(%tile: vector<[2]x[2]xf64>, %slice: vector<[2
// -----
// CHECK-LABEL: @vector_insert_element_i32(
-// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME: %[[EL:.*]]: i32,
// CHECK-SAME: %[[ROW:.*]]: index,
// CHECK-SAME: %[[COL:.*]]: index)
-func.func @vector_insert_element_i32(%tile: vector<[4]x[4]xi32>, %el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> {
- // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
- // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
- // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
- // CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
- // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+func.func @vector_insert_element_i32(%el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> {
+ // CHECK-DAG: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+ // CHECK-DAG: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+ // CHECK-DAG: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
+ // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[ROW_I32]]) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
// CHECK-NEXT: %[[NEW_SLICE:.*]] = vector.insert %[[EL]], %[[SLICE]] [%[[COL]]] : i32 into vector<[4]xi32>
// CHECK-NEXT: %[[SLICE_INDEX:.*]] = arith.index_castui %[[ROW]] : index to i32
- // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+ // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+ %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
%new_tile = vector.insert %el, %tile[%row, %col] : i32 into vector<[4]x[4]xi32>
return %new_tile : vector<[4]x[4]xi32>
}
@@ -657,9 +625,10 @@ func.func @vector_insert_element_i32(%tile: vector<[4]x[4]xi32>, %el: i32, %row:
// -----
// CHECK-LABEL: @vector_insert_element_i8
-func.func @vector_insert_element_i8(%tile: vector<[16]x[16]xi8>, %el: i8, %row: index, %col: index) -> vector<[16]x[16]xi8> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+func.func @vector_insert_element_i8(%el: i8, %row: index, %col: index) -> vector<[16]x[16]xi8> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+ %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
%new_tile = vector.insert %el, %tile[%row, %col] : i8 into vector<[16]x[16]xi8>
return %new_tile : vector<[16]x[16]xi8>
}
@@ -667,9 +636,10 @@ func.func @vector_insert_element_i8(%tile: vector<[16]x[16]xi8>, %el: i8, %row:
// -----
// CHECK-LABEL: @vector_insert_element_i16
-func.func @vector_insert_element_i16(%tile: vector<[8]x[8]xi16>, %el: i16, %row: index, %col: index) -> vector<[8]x[8]xi16> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+func.func @vector_insert_element_i16(%el: i16, %row: index, %col: index) -> vector<[8]x[8]xi16> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+ %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
%new_tile = vector.insert %el, %tile[%row, %col] : i16 into vector<[8]x[8]xi16>
return %new_tile : vector<[8]x[8]xi16>
}
@@ -677,9 +647,10 @@ func.func @vector_insert_element_i16(%tile: vector<[8]x[8]xi16>, %el: i16, %row:
// -----
// CHECK-LABEL: @vector_insert_element_i64
-func.func @vector_insert_element_i64(%tile: vector<[2]x[2]xi64>, %el: i64, %row: index, %col: index) -> vector<[2]x[2]xi64> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+func.func @vector_insert_element_i64(%el: i64, %row: index, %col: index) -> vector<[2]x[2]xi64> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+ %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
%new_tile = vector.insert %el, %tile[%row, %col] : i64 into vector<[2]x[2]xi64>
return %new_tile : vector<[2]x[2]xi64>
}
@@ -687,9 +658,10 @@ func.func @vector_insert_element_i64(%tile: vector<[2]x[2]xi64>, %el: i64, %row:
// -----
// CHECK-LABEL: @vector_insert_element_i128
-func.func @vector_insert_element_i128(%tile: vector<[1]x[1]xi128>, %el: i128, %row: index, %col: index) -> vector<[1]x[1]xi128> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+func.func @vector_insert_element_i128(%el: i128, %row: index, %col: index) -> vector<[1]x[1]xi128> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+ %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
%new_tile = vector.insert %el, %tile[%row, %col] : i128 into vector<[1]x[1]xi128>
return %new_tile : vector<[1]x[1]xi128>
}
@@ -697,9 +669,10 @@ func.func @vector_insert_element_i128(%tile: vector<[1]x[1]xi128>, %el: i128, %r
// -----
// CHECK-LABEL: @vector_insert_element_f16
-func.func @vector_insert_element_f16(%tile: vector<[8]x[8]xf16>, %el: f16, %row: index, %col: index) -> vector<[8]x[8]xf16> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+func.func @vector_insert_element_f16(%el: f16, %row: index, %col: index) -> vector<[8]x[8]xf16> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+ %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
%new_tile = vector.insert %el, %tile[%row, %col] : f16 into vector<[8]x[8]xf16>
return %new_tile : vector<[8]x[8]xf16>
}
@@ -707,9 +680,10 @@ func.func @vector_insert_element_f16(%tile: vector<[8]x[8]xf16>, %el: f16, %row:
// -----
// CHECK-LABEL: @vector_insert_element_bf16
-func.func @vector_insert_element_bf16(%tile: vector<[8]x[8]xbf16>, %el: bf16, %row: index, %col: index) -> vector<[8]x[8]xbf16> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @vector_insert_element_bf16(%el: bf16, %row: index, %col: index) -> vector<[8]x[8]xbf16> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+ %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
%new_tile = vector.insert %el, %tile[%row, %col] : bf16 into vector<[8]x[8]xbf16>
return %new_tile : vector<[8]x[8]xbf16>
}
@@ -717,9 +691,10 @@ func.func @vector_insert_element_bf16(%tile: vector<[8]x[8]xbf16>, %el: bf16, %r
// -----
// CHECK-LABEL: @vector_insert_element_f32
-func.func @vector_insert_element_f32(%tile: vector<[4]x[4]xf32>, %el: f32, %row: index, %col: index) -> vector<[4]x[4]xf32> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+func.func @vector_insert_element_f32(%el: f32, %row: index, %col: index) -> vector<[4]x[4]xf32> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+ %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%new_tile = vector.insert %el, %tile[%row, %col] : f32 into vector<[4]x[4]xf32>
return %new_tile : vector<[4]x[4]xf32>
}
@@ -727,9 +702,10 @@ func.func @vector_insert_element_f32(%tile: vector<[4]x[4]xf32>, %el: f32, %row:
// -----
// CHECK-LABEL: @vector_insert_element_f64
-func.func @vector_insert_element_f64(%tile: vector<[2]x[2]xf64>, %el: f64, %row: index, %col: index) -> vector<[2]x[2]xf64> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
- // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+func.func @vector_insert_element_f64(%el: f64, %row: index, %col: index) -> vector<[2]x[2]xf64> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
+ // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+ %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
%new_tile = vector.insert %el, %tile[%row, %col] : f64 into vector<[2]x[2]xf64>
return %new_tile : vector<[2]x[2]xf64>
}
@@ -741,14 +717,13 @@ func.func @vector_insert_element_f64(%tile: vector<[2]x[2]xf64>, %el: f64, %row:
// -----
// CHECK-LABEL: @vector_extract_slice_i32(
-// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME: %[[INDEX:.*]]: index)
-func.func @vector_extract_slice_i32(%tile: vector<[4]x[4]xi32>, %row: index) -> vector<[4]xi32> {
- // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+func.func @vector_extract_slice_i32(%row: index) -> vector<[4]xi32> {
// CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
- // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+ // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
// CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_cast %[[INDEX]] : index to i32
- // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_SLICE_INDEX]]) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
+ %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
%slice = vector.extract %tile[%row] : vector<[4]xi32> from vector<[4]x[4]xi32>
return %slice : vector<[4]xi32>
}
@@ -756,8 +731,9 @@ func.func @vector_extract_slice_i32(%tile: vector<[4]x[4]xi32>, %row: index) ->
// -----
// CHECK-LABEL: @vector_extract_slice_i8
-func.func @vector_extract_slice_i8(%tile: vector<[16]x[16]xi8>, %row: index) -> vector<[16]xi8> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+func.func @vector_extract_slice_i8(%row: index) -> vector<[16]xi8> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
+ %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
%slice = vector.extract %tile[%row] : vector<[16]xi8> from vector<[16]x[16]xi8>
return %slice : vector<[16]xi8>
}
@@ -765,8 +741,9 @@ func.func @vector_extract_slice_i8(%tile: vector<[16]x[16]xi8>, %row: index) ->
// -----
// CHECK-LABEL: @vector_extract_slice_i16
-func.func @vector_extract_slice_i16(%tile: vector<[8]x[8]xi16>, %row: index) -> vector<[8]xi16> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+func.func @vector_extract_slice_i16(%row: index) -> vector<[8]xi16> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
+ %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
%slice = vector.extract %tile[%row] : vector<[8]xi16> from vector<[8]x[8]xi16>
return %slice : vector<[8]xi16>
}
@@ -774,8 +751,9 @@ func.func @vector_extract_slice_i16(%tile: vector<[8]x[8]xi16>, %row: index) ->
// -----
// CHECK-LABEL: @vector_extract_slice_i64
-func.func @vector_extract_slice_i64(%tile: vector<[2]x[2]xi64>, %row: index) -> vector<[2]xi64> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+func.func @vector_extract_slice_i64(%row: index) -> vector<[2]xi64> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
+ %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
%slice = vector.extract %tile[%row] : vector<[2]xi64> from vector<[2]x[2]xi64>
return %slice : vector<[2]xi64>
}
@@ -783,8 +761,9 @@ func.func @vector_extract_slice_i64(%tile: vector<[2]x[2]xi64>, %row: index) ->
// -----
// CHECK-LABEL: @vector_extract_slice_i128
-func.func @vector_extract_slice_i128(%tile: vector<[1]x[1]xi128>, %row: index) -> vector<[1]xi128> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @vector_extract_slice_i128(%row: index) -> vector<[1]xi128> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+ %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
%slice = vector.extract %tile[%row] : vector<[1]xi128> from vector<[1]x[1]xi128>
return %slice : vector<[1]xi128>
}
@@ -792,8 +771,9 @@ func.func @vector_extract_slice_i128(%tile: vector<[1]x[1]xi128>, %row: index) -
// -----
// CHECK-LABEL: @vector_extract_slice_f16
-func.func @vector_extract_slice_f16(%tile: vector<[8]x[8]xf16>, %row: index) -> vector<[8]xf16> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+func.func @vector_extract_slice_f16(%row: index) -> vector<[8]xf16> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
+ %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
%slice = vector.extract %tile[%row] : vector<[8]xf16> from vector<[8]x[8]xf16>
return %slice : vector<[8]xf16>
}
@@ -801,8 +781,9 @@ func.func @vector_extract_slice_f16(%tile: vector<[8]x[8]xf16>, %row: index) ->
// -----
// CHECK-LABEL: @vector_extract_slice_bf16
-func.func @vector_extract_slice_bf16(%tile: vector<[8]x[8]xbf16>, %row: index) -> vector<[8]xbf16> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+func.func @vector_extract_slice_bf16(%row: index) -> vector<[8]xbf16> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
+ %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
%slice = vector.extract %tile[%row] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
return %slice : vector<[8]xbf16>
}
@@ -810,8 +791,9 @@ func.func @vector_extract_slice_bf16(%tile: vector<[8]x[8]xbf16>, %row: index) -
// -----
// CHECK-LABEL: @vector_extract_slice_f32
-func.func @vector_extract_slice_f32(%tile: vector<[4]x[4]xf32>, %row: index) -> vector<[4]xf32> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+func.func @vector_extract_slice_f32(%row: index) -> vector<[4]xf32> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
+ %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%slice = vector.extract %tile[%row] : vector<[4]xf32> from vector<[4]x[4]xf32>
return %slice : vector<[4]xf32>
}
@@ -819,8 +801,9 @@ func.func @vector_extract_slice_f32(%tile: vector<[4]x[4]xf32>, %row: index) ->
// -----
// CHECK-LABEL: @vector_extract_slice_f64
-func.func @vector_extract_slice_f64(%tile: vector<[2]x[2]xf64>, %row: index) -> vector<[2]xf64> {
- // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+func.func @vector_extract_slice_f64(%row: index) -> vector<[2]xf64> {
+ // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
+ %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
%slice = vector.extract %tile[%row] : vector<[2]xf64> from vector<[2]x[2]xf64>
return %slice : vector<[2]xf64>
}
@@ -828,16 +811,15 @@ func.func @vector_extract_slice_f64(%tile: vector<[2]x[2]xf64>, %row: index) ->
// -----
// CHECK-LABEL: @vector_extract_element(
-// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME: %[[ROW:.*]]: index,
// CHECK-SAME: %[[COL:.*]]: index)
-func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col: index) -> i32 {
- // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+func.func @vector_extract_element(%row: index, %col: index) -> i32 {
// CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
- // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+ // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
// CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
- // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[ROW_I32]]) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
// CHECK-NEXT: %[[EL:.*]] = vector.extract %[[SLICE]]{{\[}}%[[COL]]] : i32 from vector<[4]xi32>
+ %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
%el = vector.extract %tile[%row, %col] : i32 from vector<[4]x[4]xi32>
return %el : i32
}
@@ -845,9 +827,10 @@ func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col:
// -----
// CHECK-LABEL: @vector_extract_element_i8
-func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %col: index) -> i8 {
- // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+func.func @vector_extract_element_i8(%row: index, %col: index) -> i8 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i8 from vector<[16]xi8>
+ %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
%el = vector.extract %tile[%row, %col] : i8 from vector<[16]x[16]xi8>
return %el : i8
}
@@ -855,9 +838,10 @@ func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %
// -----
// CHECK-LABEL: @vector_extract_element_i16
-func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %col: index) -> i16 {
- // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+func.func @vector_extract_element_i16(%row: index, %col: index) -> i16 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i16 from vector<[8]xi16>
+ %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
%el = vector.extract %tile[%row, %col] : i16 from vector<[8]x[8]xi16>
return %el : i16
}
@@ -865,9 +849,10 @@ func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %
// -----
// CHECK-LABEL: @vector_extract_element_i64
-func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %col: index) -> i64 {
- // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+func.func @vector_extract_element_i64(%row: index, %col: index) -> i64 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i64 from vector<[2]xi64>
+ %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
%el = vector.extract %tile[%row, %col] : i64 from vector<[2]x[2]xi64>
return %el : i64
}
@@ -875,9 +860,10 @@ func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %
// -----
// CHECK-LABEL: @vector_extract_element_i128
-func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index, %col: index) -> i128 {
- // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @vector_extract_element_i128(%row: index, %col: index) -> i128 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i128 from vector<[1]xi128>
+ %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
%el = vector.extract %tile[%row, %col] : i128 from vector<[1]x[1]xi128>
return %el : i128
}
@@ -885,9 +871,10 @@ func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index,
// -----
// CHECK-LABEL: @vector_extract_element_f16
-func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, %col: index) -> f16 {
- // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+func.func @vector_extract_element_f16(%row: index, %col: index) -> f16 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f16 from vector<[8]xf16>
+ %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
%el = vector.extract %tile[%row, %col] : f16 from vector<[8]x[8]xf16>
return %el : f16
}
@@ -895,9 +882,10 @@ func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, %
// -----
// CHECK-LABEL: @vector_extract_element_bf16
-func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index, %col: index) -> bf16 {
- // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+func.func @vector_extract_element_bf16(%row: index, %col: index) -> bf16 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : bf16 from vector<[8]xbf16>
+ %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
%el = vector.extract %tile[%row, %col] : bf16 from vector<[8]x[8]xbf16>
return %el : bf16
}
@@ -905,9 +893,10 @@ func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index,
// -----
// CHECK-LABEL: @vector_extract_element_f32
-func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %col: index) -> f32 {
- // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+func.func @vector_extract_element_f32(%row: index, %col: index) -> f32 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f32 from vector<[4]xf32>
+ %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%el = vector.extract %tile[%row, %col] : f32 from vector<[4]x[4]xf32>
return %el : f32
}
@@ -915,9 +904,10 @@ func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %
// -----
// CHECK-LABEL: @vector_extract_element_f64
-func.func @vector_extract_element_f64(%tile: vector<[2]x[2]xf64>, %row: index, %col: index) -> f64 {
- // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 {
+ // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
// CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f64 from vector<[2]xf64>
+ %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
%el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
return %el : f64
}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index ae3b260da83a2c1..2491b2e2468cdaf 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -452,8 +452,7 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
-// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
@@ -500,8 +499,7 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
// CHECK-LABEL: func.func @splat_vec2d_from_i32(
// CHECK-SAME: %[[SRC:.*]]: i32) {
// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
-// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK: arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK: arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
// CHECK: scf.for {{.*}} to %[[NUM_TILE_SLICES]] {{.*}} {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 18b95cf2fdf843c..ccd984bcb036e82 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -4,9 +4,9 @@
// RUN: -lower-vector-mask \
// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// RUN: -convert-arm-sme-to-llvm -cse -canonicalize \
-// RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN: -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN: -e=entry -entry-point-result=void \
// RUN: -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
index f189fd97d66cde7..81f503bec91efcc 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
@@ -3,10 +3,10 @@
// RUN: -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
+// RUN: -convert-vector-to-scf -allocate-arm-sme-tiles -cse -arm-sve-legalize-vector-storage \
// RUN: -convert-arm-sme-to-llvm \
// RUN: -convert-vector-to-llvm=enable-arm-sve \
-// RUN: -cse -canonicalize -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN: -cse -canonicalize -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN: -e=main -entry-point-result=void \
// RUN: -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
index 413cb06b0ae1894..2c9edbf560b0110 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
@@ -2,11 +2,11 @@
// RUN: -transform-interpreter -test-transform-dialect-erase-schedule \
// RUN: -canonicalize \
// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
// RUN: -convert-arm-sme-to-llvm \
// RUN: -convert-vector-to-llvm=enable-arm-sve \
-// RUN: -cse -canonicalize -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN: -cse -canonicalize -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN: -e=main -entry-point-result=void \
// RUN: -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index 0c186cc373a3b32..936163d1cd30d99 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -1,9 +1,9 @@
// DEFINE: %{entry_point} = entry
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 442a70cacd66508..8c73c24d695cfb2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -1,9 +1,9 @@
// DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
+// DEFINE: -test-lower-to-llvm -o %t
// DEFINE: %{run} = %mcr_aarch64_cmd %t \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 74b51dcc9b4df3a..965337c60b9ffdd 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -1,9 +1,9 @@
// DEFINE: %{entry_point} = test_outerproduct_no_accumulator_2x2xf64
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
+// DEFINE: -test-lower-to-llvm -o %t
// DEFINE: %{run} = %mcr_aarch64_cmd %t \
// DEFINE: -march=aarch64 -mattr=+sve,+sme-f64f64 \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
index 82f38b4dbfa9d1f..4ca61a089bdf52d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -1,9 +1,9 @@
// DEFINE: %{entry_point} = entry
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index 3b218aefcd415ff..14dca2d4d7082aa 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -1,9 +1,9 @@
// DEFINE: %{entry_point} = entry
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
index e2cbe735fa4ff06..2751c2d136485e2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -1,9 +1,9 @@
// DEFINE: %{entry_point} = entry
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index 6e33a421bf799a2..b45f24f6c8fddaf 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// RUN: -convert-arm-sme-to-llvm -cse -canonicalize \
-// RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN: -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN: -march=aarch64 -mattr=+sve,+sme \
// RUN: -e entry -entry-point-result=i32 \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index 961bb274d1e3352..5d2d0a73992f1e0 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -1,9 +1,9 @@
// DEFINE: %{entry_point} = za0_d_f64
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=i32 \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index 25ef1799e63adb1..af5fe236e5bd577 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,8 +1,7 @@
// DEFINE: %{entry_point} = entry
// DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE: -convert-arm-sme-to-llvm \
-// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE: -convert-arm-sme-to-llvm -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=i32 \
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index b3202b26f8e1e3d..7c9976bed912734 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -4,10 +4,9 @@
llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
%nxv4i1 : vector<[4]xi1>,
%nxv16i8 : vector<[16]xi8>) {
- %tile = llvm.mlir.constant(0 : index) : i32
// expected-error @+1 {{failed to verify that all of {predicate, vector} have same shape}}
- "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv16i8) :
- (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
+ "arm_sme.intr.write.horiz"(%tileslice, %nxv4i1, %nxv16i8) <{tile_id = 0 : i32}> :
+ (i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
llvm.return
}
@@ -16,10 +15,9 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
%tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
) -> vector<[3]xf32> {
- %tile = llvm.mlir.constant(0 : index) : i32
// expected-error @+1 {{failed to verify that all of {vector, predicate, res} have same shape}}
- %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
- (vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32>
+ %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tileslice) <{tile_id = 0 : i32}> :
+ (vector<[16]xi8>, vector<[4]xi1>, i32) -> vector<[3]xf32>
llvm.return %res : vector<[3]xf32>
}
@@ -28,9 +26,8 @@ llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
%tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv4f32 : vector<[4]xf32>
) -> vector<[3]xi32> {
- %tile = llvm.mlir.constant(0 : index) : i32
// expected-error @+1 {{failed to verify that all of {vector, res} have same element type}}
- %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice) :
- (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}> :
+ (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
llvm.return %res : vector<[4]xi32>
}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 767d89a75eec326..edc1f749130440c 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -2,9 +2,8 @@
// CHECK-LABEL: @arm_sme_zero
llvm.func @arm_sme_zero() {
- %c0 = llvm.mlir.constant(0 : index) : i32
// CHECK: call void @llvm.aarch64.sme.zero(i32 0)
- "arm_sme.intr.zero"(%c0) : (i32) -> ()
+ "arm_sme.intr.zero"() <{tile_mask = 0 : i32}> : () -> ()
llvm.return
}
@@ -18,19 +17,18 @@ llvm.func @arm_sme_fmopa(%nxv2f64 : vector<[2]xf64>,
%nxv2i1 : vector<[2]xi1>,
%nxv4i1 : vector<[4]xi1>,
%nxv8i1 : vector<[8]xi1>) {
- %c0 = llvm.mlir.constant(0 : index) : i32
// CHECK: call void @llvm.aarch64.sme.mopa.nxv2f64
- "arm_sme.intr.mopa"(%c0, %nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) :
- (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
+ "arm_sme.intr.mopa"(%nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) <{tile_id = 0 : i32}> :
+ (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
// CHECK: call void @llvm.aarch64.sme.mopa.nxv4f32
- "arm_sme.intr.mopa"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) :
- (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
+ "arm_sme.intr.mopa"(%nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) <{tile_id = 0 : i32}> :
+ (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
// CHECK: call void @llvm.aarch64.sme.mopa.wide.nxv8f16
- "arm_sme.intr.mopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) :
- (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
+ "arm_sme.intr.mopa.wide"(%nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
// CHECK: call void @llvm.aarch64.sme.mopa.wide.nxv8bf16
- "arm_sme.intr.mopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) :
- (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
+ "arm_sme.intr.mopa.wide"(%nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
llvm.return
}
@@ -41,31 +39,30 @@ llvm.func @arm_sme_imopa(%nxv8i16 : vector<[8]xi16>,
%nxv16i8 : vector<[16]xi8>,
%nxv8i1 : vector<[8]xi1>,
%nxv16i1 : vector<[16]xi1>) {
- %c0 = llvm.mlir.constant(0 : index) : i32
// CHECK: call void @llvm.aarch64.sme.smopa.wide.nxv8i16
- "arm_sme.intr.smopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
- (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+ "arm_sme.intr.smopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
// CHECK: call void @llvm.aarch64.sme.umopa.wide.nxv8i16
- "arm_sme.intr.umopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
- (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+ "arm_sme.intr.umopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
// CHECK: call void @llvm.aarch64.sme.sumopa.wide.nxv8i16
- "arm_sme.intr.sumopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
- (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+ "arm_sme.intr.sumopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
// CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv8i16
- "arm_sme.intr.usmopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
- (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+ "arm_sme.intr.usmopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
// CHECK: call void @llvm.aarch64.sme.smopa.wide.nxv16i8
- "arm_sme.intr.smopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
- (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+ "arm_sme.intr.smopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+ (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
// CHECK: call void @llvm.aarch64.sme.umopa.wide.nxv16i8
- "arm_sme.intr.umopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
- (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+ "arm_sme.intr.umopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+ (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
// CHECK: call void @llvm.aarch64.sme.sumopa.wide.nxv16i8
- "arm_sme.intr.sumopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
- (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+ "arm_sme.intr.sumopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+ (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
// CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv16i8
- "arm_sme.intr.usmopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
- (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+ "arm_sme.intr.usmopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+ (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
llvm.return
}
@@ -79,19 +76,18 @@ llvm.func @arm_sme_fmops(%nxv2f64 : vector<[2]xf64>,
%nxv2i1 : vector<[2]xi1>,
%nxv4i1 : vector<[4]xi1>,
%nxv8i1 : vector<[8]xi1>) {
- %c0 = llvm.mlir.constant(0 : index) : i32
// CHECK: call void @llvm.aarch64.sme.mops.nxv2f64
- "arm_sme.intr.mops"(%c0, %nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) :
- (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
+ "arm_sme.intr.mops"(%nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) <{tile_id = 0 : i32}> :
+ (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
// CHECK: call void @llvm.aarch64.sme.mops.nxv4f32
- "arm_sme.intr.mops"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) :
- (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
+ "arm_sme.intr.mops"(%nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) <{tile_id = 0 : i32}> :
+ (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
// CHECK: call void @llvm.aarch64.sme.mops.wide.nxv8f16
- "arm_sme.intr.mops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) :
- (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
+ "arm_sme.intr.mops.wide"(%nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
// CHECK: call void @llvm.aarch64.sme.mops.wide.nxv8bf16
- "arm_sme.intr.mops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) :
- (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
+ "arm_sme.intr.mops.wide"(%nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
llvm.return
}
@@ -102,31 +98,30 @@ llvm.func @arm_sme_imops(%nxv8i16 : vector<[8]xi16>,
%nxv16i8 : vector<[16]xi8>,
%nxv8i1 : vector<[8]xi1>,
%nxv16i1 : vector<[16]xi1>) {
- %c0 = llvm.mlir.constant(0 : index) : i32
// CHECK: call void @llvm.aarch64.sme.smops.wide.nxv8i16
- "arm_sme.intr.smops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
- (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+ "arm_sme.intr.smops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
// CHECK: call void @llvm.aarch64.sme.umops.wide.nxv8i16
- "arm_sme.intr.umops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
- (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+ "arm_sme.intr.umops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
// CHECK: call void @llvm.aarch64.sme.sumops.wide.nxv8i16
- "arm_sme.intr.sumops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
- (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+ "arm_sme.intr.sumops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
// CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv8i16
- "arm_sme.intr.usmops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
- (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+ "arm_sme.intr.usmops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
// CHECK: call void @llvm.aarch64.sme.smops.wide.nxv16i8
- "arm_sme.intr.smops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
- (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+ "arm_sme.intr.smops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+ (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
// CHECK: call void @llvm.aarch64.sme.umops.wide.nxv16i8
- "arm_sme.intr.umops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
- (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+ "arm_sme.intr.umops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+ (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
// CHECK: call void @llvm.aarch64.sme.sumops.wide.nxv16i8
- "arm_sme.intr.sumops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
- (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+ "arm_sme.intr.sumops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+ (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
// CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv16i8
- "arm_sme.intr.usmops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
- (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+ "arm_sme.intr.usmops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+ (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
llvm.return
}
@@ -141,35 +136,35 @@ llvm.func @arm_sme_load(%nxv1i1 : vector<[1]xi1>,
%ptr : !llvm.ptr) {
%c0 = llvm.mlir.constant(0 : index) : i32
// CHECK: call void @llvm.aarch64.sme.ld1q.horiz
- "arm_sme.intr.ld1q.horiz"(%nxv1i1, %ptr, %c0, %c0) :
- (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.ld1q.horiz"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[1]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.ld1d.horiz
- "arm_sme.intr.ld1d.horiz"(%nxv2i1, %ptr, %c0, %c0) :
- (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.ld1d.horiz"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[2]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.ld1w.horiz
- "arm_sme.intr.ld1w.horiz"(%nxv4i1, %ptr, %c0, %c0) :
- (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.ld1w.horiz"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[4]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.ld1h.horiz
- "arm_sme.intr.ld1h.horiz"(%nxv8i1, %ptr, %c0, %c0) :
- (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.ld1h.horiz"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.ld1b.horiz
- "arm_sme.intr.ld1b.horiz"(%nxv16i1, %ptr, %c0, %c0) :
- (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.ld1b.horiz"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[16]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.ld1q.vert
- "arm_sme.intr.ld1q.vert"(%nxv1i1, %ptr, %c0, %c0) :
- (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.ld1q.vert"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[1]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.ld1d.vert
- "arm_sme.intr.ld1d.vert"(%nxv2i1, %ptr, %c0, %c0) :
- (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.ld1d.vert"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[2]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.ld1w.vert
- "arm_sme.intr.ld1w.vert"(%nxv4i1, %ptr, %c0, %c0) :
- (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.ld1w.vert"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[4]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.ld1h.vert
- "arm_sme.intr.ld1h.vert"(%nxv8i1, %ptr, %c0, %c0) :
- (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.ld1h.vert"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.ld1b.vert
- "arm_sme.intr.ld1b.vert"(%nxv16i1, %ptr, %c0, %c0) :
- (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.ld1b.vert"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[16]xi1>, !llvm.ptr, i32) -> ()
llvm.return
}
@@ -184,35 +179,35 @@ llvm.func @arm_sme_store(%nxv1i1 : vector<[1]xi1>,
%ptr : !llvm.ptr) {
%c0 = llvm.mlir.constant(0 : index) : i32
// CHECK: call void @llvm.aarch64.sme.st1q.horiz
- "arm_sme.intr.st1q.horiz"(%nxv1i1, %ptr, %c0, %c0) :
- (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.st1q.horiz"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[1]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.st1d.horiz
- "arm_sme.intr.st1d.horiz"(%nxv2i1, %ptr, %c0, %c0) :
- (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.st1d.horiz"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[2]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.st1w.horiz
- "arm_sme.intr.st1w.horiz"(%nxv4i1, %ptr, %c0, %c0) :
- (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.st1w.horiz"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[4]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.st1h.horiz
- "arm_sme.intr.st1h.horiz"(%nxv8i1, %ptr, %c0, %c0) :
- (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.st1h.horiz"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.st1b.horiz
- "arm_sme.intr.st1b.horiz"(%nxv16i1, %ptr, %c0, %c0) :
- (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.st1b.horiz"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[16]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.st1q.vert
- "arm_sme.intr.st1q.vert"(%nxv1i1, %ptr, %c0, %c0) :
- (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.st1q.vert"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[1]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.st1d.vert
- "arm_sme.intr.st1d.vert"(%nxv2i1, %ptr, %c0, %c0) :
- (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.st1d.vert"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[2]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.st1w.vert
- "arm_sme.intr.st1w.vert"(%nxv4i1, %ptr, %c0, %c0) :
- (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.st1w.vert"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[4]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.st1h.vert
- "arm_sme.intr.st1h.vert"(%nxv8i1, %ptr, %c0, %c0) :
- (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.st1h.vert"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.st1b.vert
- "arm_sme.intr.st1b.vert"(%nxv16i1, %ptr, %c0, %c0) :
- (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+ "arm_sme.intr.st1b.vert"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+ (vector<[16]xi1>, !llvm.ptr, i32) -> ()
// CHECK: call void @llvm.aarch64.sme.str
"arm_sme.intr.str"(%c0, %ptr, %c0) : (i32, !llvm.ptr, i32) -> ()
llvm.return
@@ -236,34 +231,33 @@ llvm.func @arm_sme_vector_to_tile_horiz(%tileslice : i32,
%nxv8bf16 : vector<[8]xbf16>,
%nxv4f32 : vector<[4]xf32>,
%nxv2f64 : vector<[2]xf64>) {
- %tile = llvm.mlir.constant(0 : index) : i32
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv16i8
- "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
- (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+ "arm_sme.intr.write.horiz"(%tileslice, %nxv16i1, %nxv16i8) <{tile_id = 0 : i32}> :
+ (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8i16
- "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
- (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+ "arm_sme.intr.write.horiz"(%tileslice, %nxv8i1, %nxv8i16) <{tile_id = 0 : i32}> :
+ (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4i32
- "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
- (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+ "arm_sme.intr.write.horiz"(%tileslice, %nxv4i1, %nxv4i32) <{tile_id = 0 : i32}> :
+ (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2i64
- "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
- (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+ "arm_sme.intr.write.horiz"(%tileslice, %nxv2i1, %nxv2i64) <{tile_id = 0 : i32}> :
+ (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv1i128
- "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
- (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+ "arm_sme.intr.write.horiz"(%tileslice, %nxv1i1, %nxv1i128) <{tile_id = 0 : i32}> :
+ (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8f16
- "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
- (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+ "arm_sme.intr.write.horiz"(%tileslice, %nxv8i1, %nxv8f16) <{tile_id = 0 : i32}> :
+ (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8bf16
- "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
- (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+ "arm_sme.intr.write.horiz"(%tileslice, %nxv8i1, %nxv8bf16) <{tile_id = 0 : i32}> :
+ (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4f32
- "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
- (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+ "arm_sme.intr.write.horiz"(%tileslice, %nxv4i1, %nxv4f32) <{tile_id = 0 : i32}> :
+ (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2f64
- "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
- (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+ "arm_sme.intr.write.horiz"(%tileslice, %nxv2i1, %nxv2f64) <{tile_id = 0 : i32}> :
+ (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
llvm.return
}
@@ -285,34 +279,33 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
%nxv8bf16 : vector<[8]xbf16>,
%nxv4f32 : vector<[4]xf32>,
%nxv2f64 : vector<[2]xf64>) {
- %tile = llvm.mlir.constant(0 : index) : i32
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv16i8
- "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
- (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+ "arm_sme.intr.write.vert"(%tileslice, %nxv16i1, %nxv16i8) <{tile_id = 0 : i32}> :
+ (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv8i16
- "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
- (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+ "arm_sme.intr.write.vert"(%tileslice, %nxv8i1, %nxv8i16) <{tile_id = 0 : i32}> :
+ (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv4i32
- "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
- (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+ "arm_sme.intr.write.vert"(%tileslice, %nxv4i1, %nxv4i32) <{tile_id = 0 : i32}> :
+ (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv2i64
- "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
- (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+ "arm_sme.intr.write.vert"(%tileslice, %nxv2i1, %nxv2i64) <{tile_id = 0 : i32}> :
+ (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv1i128
- "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
- (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+ "arm_sme.intr.write.vert"(%tileslice, %nxv1i1, %nxv1i128) <{tile_id = 0 : i32}> :
+ (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv8f16
- "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
- (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+ "arm_sme.intr.write.vert"(%tileslice, %nxv8i1, %nxv8f16) <{tile_id = 0 : i32}> :
+ (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv8bf16
- "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
- (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+ "arm_sme.intr.write.vert"(%tileslice, %nxv8i1, %nxv8bf16) <{tile_id = 0 : i32}> :
+ (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv4f32
- "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
- (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+ "arm_sme.intr.write.vert"(%tileslice, %nxv4i1, %nxv4f32) <{tile_id = 0 : i32}> :
+ (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
// CHECK: call void @llvm.aarch64.sme.write.vert.nxv2f64
- "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
- (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+ "arm_sme.intr.write.vert"(%tileslice, %nxv2i1, %nxv2f64) <{tile_id = 0 : i32}> :
+ (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
llvm.return
}
@@ -334,34 +327,33 @@ llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
%nxv8bf16 : vector<[8]xbf16>,
%nxv4f32 : vector<[4]xf32>,
%nxv2f64 : vector<[2]xf64>) {
- %tile = llvm.mlir.constant(0 : index) : i32
// CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
- %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
- : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+ %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
// CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.horiz.nxv8i16
- %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice)
- : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+ %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.horiz.nxv4i32
- %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice)
- : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
// CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.horiz.nxv2i64
- %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice)
- : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+ %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
// CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.horiz.nxv1i128
- %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice)
- : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+ %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
// CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.horiz.nxv8f16
- %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice)
- : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+ %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
// CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.horiz.nxv8bf16
- %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
- : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+ %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
// CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.horiz.nxv4f32
- %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice)
- : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+ %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
// CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.horiz.nxv2f64
- %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice)
- : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+ %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
llvm.return
}
@@ -382,33 +374,32 @@ llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
%nxv8bf16 : vector<[8]xbf16>,
%nxv4f32 : vector<[4]xf32>,
%nxv2f64 : vector<[2]xf64>) {
- %tile = llvm.mlir.constant(0 : index) : i32
// CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
- %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
- : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+ %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
// CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.vert.nxv8i16
- %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice)
- : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+ %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.vert.nxv4i32
- %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice)
- : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
// CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.vert.nxv2i64
- %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice)
- : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+ %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
// CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.vert.nxv1i128
- %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice)
- : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+ %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
// CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.vert.nxv8f16
- %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice)
- : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+ %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
// CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.vert.nxv8bf16
- %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
- : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+ %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
// CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.vert.nxv4f32
- %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice)
- : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+ %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
// CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.vert.nxv2f64
- %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice)
- : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+ %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+ : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
llvm.return
}
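Every intrinsic update in this test file follows the same shape: the tile ID, previously a leading `i32` SSA operand defined by an `llvm.mlir.constant`, becomes an inline `tile_id` attribute, so the constant definition is dropped and the operand list shrinks by one. A compact before/after sketch using `arm_sme.intr.mopa`, with the same operands and types as the hunks above (purely illustrative; `%nxv4i1` and `%nxv4f32` are assumed to be function arguments, as in the test):

```mlir
// Before: tile ID passed as a compile-time-constant SSA operand.
%c0 = llvm.mlir.constant(0 : index) : i32
"arm_sme.intr.mopa"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) :
  (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()

// After: tile ID carried as an inline attribute; no constant needed.
"arm_sme.intr.mopa"(%nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) <{tile_id = 0 : i32}> :
  (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
```

The one exception is `arm_sme.intr.zero`, whose immediate becomes a `tile_mask` attribute rather than `tile_id`.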
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 0ed4f94d5802b09..1e18d23c0044159 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -21,9 +21,9 @@
using namespace mlir;
// This is needed because these matchers are defined as overloaded functions.
-using HasOpAttrName = detail::AttrOpMatcher(StringRef);
-using HasOpName = detail::NameOpMatcher(StringRef);
-using IsConstantOp = detail::constant_op_matcher();
+using HasOpAttrName = mlir::detail::AttrOpMatcher(StringRef);
+using HasOpName = mlir::detail::NameOpMatcher(StringRef);
+using IsConstantOp = mlir::detail::constant_op_matcher();
namespace test {
#ifdef MLIR_INCLUDE_TESTS