[Mlir-commits] [mlir] 6ff9761 - [mlir][ArmSME] Add custom get_tile_id and cast ops
Cullen Rhodes
llvmlistbot at llvm.org
Tue Jul 18 00:42:08 PDT 2023
Author: Cullen Rhodes
Date: 2023-07-18T07:41:45Z
New Revision: 6ff9761a69df24d33f7c1047a956911bf28a754b
URL: https://github.com/llvm/llvm-project/commit/6ff9761a69df24d33f7c1047a956911bf28a754b
DIFF: https://github.com/llvm/llvm-project/commit/6ff9761a69df24d33f7c1047a956911bf28a754b.diff
LOG: [mlir][ArmSME] Add custom get_tile_id and cast ops
This patch adds three new custom ops to the ArmSME dialect:
* arm_sme.get_tile_id - returns a scalar integer representing an SME
"virtual tile" that is not in use.
* arm_sme.cast_tile_to_vector - casts from a tile id to a 2-d scalable
vector type, which represents an SME "virtual tile".
* arm_sme.cast_vector_to_tile - casts from a 2-d scalable vector type,
which represents an SME "virtual tile", to a tile id.
The 'arm_sme.get_tile_id' op currently only supports tile 0; a follow-up
patch will implement proper tile allocation. A further follow-up patch
will demonstrate loads/stores to/from ZA using these ops.
See the op descriptions for further details and examples.
Thanks to @paulwalker-arm and @awarzynski for helping drive this.
Reviewed By: awarzynski, dcaballe
Differential Revision: https://reviews.llvm.org/D154941
Added:
mlir/test/Dialect/ArmSME/canonicalize.mlir
mlir/test/Dialect/ArmSME/invalid.mlir
mlir/test/Dialect/ArmSME/roundtrip.mlir
Modified:
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
mlir/include/mlir/IR/OpBase.td
mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 140ed51b101b97..19283155b21714 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -36,6 +36,166 @@ def ArmSME_Dialect : Dialect {
let dependentDialects = ["scf::SCFDialect"];
}
+//===----------------------------------------------------------------------===//
+// ArmSME type definitions
+//===----------------------------------------------------------------------===//
+
+// Type constraint for a 2-d vector whose element type is `datatype`, whose
+// dimensions are all scalable, and whose shape is exactly `dims`
+// (i.e. vector<[d0]x[d1]xT>). Each such type models one SME "virtual tile".
+// `description` is the leading part of the constraint summary, which is what
+// verifier error messages print for a mismatching type.
+// IsVectorOfShape casts `$_self` to VectorType unconditionally, so it is
+// listed after the rank/scalability checks (presumably relying on And<>
+// evaluating its predicates in order).
+class SMETileType<Type datatype, list<int> dims, string description>
+ : ShapedContainerType<[datatype],
+ And<[IsVectorOfRankPred<[2]>, allDimsScalableVectorTypePred,
+ IsVectorOfShape<dims>]>,
+ description>;
+
+// Integer tile types. In every entry the base dimension times the element
+// bitwidth is 128 (16 x 8, 8 x 16, 4 x 32, 2 x 64, 1 x 128).
+def nxnxv16i8 : SMETileType<I8, [16, 16], "vector<[16]x[16]xi8>">;
+def nxnxv8i16 : SMETileType<I16, [8, 8 ], "vector<[8]x[8]xi16>">;
+def nxnxv4i32 : SMETileType<I32, [4, 4 ], "vector<[4]x[4]xi32>">;
+def nxnxv2i64 : SMETileType<I64, [2, 2 ], "vector<[2]x[2]xi64>">;
+def nxnxv1i128 : SMETileType<I128, [1, 1 ], "vector<[1]x[1]xi128>">;
+
+// Floating-point tile types, following the same 128-bit base-row rule.
+def nxnxv8f16 : SMETileType<F16, [8, 8 ], "vector<[8]x[8]xf16>">;
+def nxnxv8bf16 : SMETileType<BF16, [8, 8 ], "vector<[8]x[8]xbf16>">;
+def nxnxv4f32 : SMETileType<F32, [4, 4 ], "vector<[4]x[4]xf32>">;
+def nxnxv2f64 : SMETileType<F64, [2, 2 ], "vector<[2]x[2]xf64>">;
+
+// Any of the tile types above; used as the vector operand/result constraint
+// of the arm_sme cast ops.
+def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
+ nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
+
+// A type constraint that verifies the bitwidth of the scalar integer
+// `tile_id` matches the element bitwidth of the "virtual tile" `vector`.
+// It is attached to the cast ops that convert between the two.
+//
+// The transformer maps the vector type to the signless integer type of its
+// element bitwidth. `getElementTypeBitWidth()` handles both integer and
+// floating-point element types, so no per-kind dispatch is needed.
+def TileElementWidthMatchesTileID : TypesMatchWith<
+  "`tile_id` has the same number of bits as elements in `vector`",
+  "vector", "tile_id",
+  "::mlir::IntegerType::get("
+      "$_self.getContext(),"
+      "::llvm::cast<::mlir::VectorType>($_self).getElementTypeBitWidth())">;
+
+//===----------------------------------------------------------------------===//
+// ArmSME op definitions
+//===----------------------------------------------------------------------===//
+
+// Base class for the ArmSME ops defined in this section: it binds the op to
+// ArmSME_Dialect and forwards the mnemonic and the trait list to Op<>.
+class ArmSME_Op<string mnemonic, list<Trait> traits = []>
+    : Op<ArmSME_Dialect, mnemonic, traits>;
+
+def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthMatchesTileID]> {
+  let summary = "Cast from tile id to 2-d scalable vector type";
+  let description = [{
+    A `cast_tile_to_vector` operation does a cast from a tile id to a 2-d
+    scalable vector type, which represents an SME "virtual tile". This would
+    normally be used when lowering operations that return "virtual tile" vector
+    types to model the output. This is required to preserve dataflow as SME
+    intrinsics have no return values.
+
+    Example:
+
+    Input:
+    ```mlir
+    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    ```
+
+    After lowering `vector.load`:
+    ```mlir
+    %tile_id = arm_sme.get_tile_id : i32
+    scf.for %vnum = %c0 to %num_vectors step %c1 {
+      // ...
+      "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+    }
+    %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    ```
+
+    In the example above, the `vector.load` can't be replaced with an SME
+    intrinsic that has no outputs since it is used by the `vector.store`.
+    However, by inserting a `cast_tile_to_vector` op after the load intrinsics
+    the `vector.load` can be replaced. This enables "local" rewrites on
+    individual vector ops, rather than "global" rewrites that would have to
+    look at the vector op uses and also lower them.
+
+    Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
+    the cast away if it comes from an `arm_sme.cast_vector_to_tile`.
+  }];
+  let arguments = (ins AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let results = (outs SMETile:$vector);
+  let assemblyFormat =
+    "$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
+  let hasCanonicalizeMethod = 1;
+}
+
+def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthMatchesTileID]> {
+  let summary = "Cast from 2-d scalable vector type to tile id";
+  let description = [{
+    A `cast_vector_to_tile` operation does a cast from a 2-d scalable vector
+    type, which represents an SME "virtual tile", to a tile id. This is
+    required to preserve dataflow as the SME intrinsics have no return values.
+
+    Example:
+
+    Input:
+    ```mlir
+    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    ```
+
+    After lowering `vector.store`:
+    ```mlir
+    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    scf.for %vnum = %c0 to %num_vectors step %c1 {
+      // ...
+      %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
+      "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+    }
+    ```
+
+    Canonicalization will look through `arm_sme.cast_vector_to_tile` and fold
+    the cast away if it comes from an `arm_sme.cast_tile_to_vector`.
+  }];
+  let arguments = (ins SMETile:$vector);
+  let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let assemblyFormat =
+    "$vector attr-dict `:` type($vector) `to` type($tile_id)";
+  let hasCanonicalizeMethod = 1;
+}
+
+// NOTE: deliberately not `Pure`. Two `get_tile_id` ops of the same type have
+// no operands and identical attributes, so a side-effect-free op would be
+// merged into a single value by CSE. Each op must instead yield its own
+// "virtual tile" id, as in the two-id example below.
+def GetTileID : ArmSME_Op<"get_tile_id"> {
+  let summary = "Returns an SME \"virtual tile\" id";
+  let description = [{
+    A `get_tile_id` operation returns a scalar integer representing an SME
+    "virtual tile" id. The bitwidth of the scalar indicates the element
+    bitwidth of the "virtual tile".
+
+    A tile id is scoped to the function in which it is created: it cannot be
+    passed to, or returned from, a function.
+
+    Example:
+    ```mlir
+    // Allocate and return an 8-bit element "virtual tile" id
+    %za0_b = arm_sme.get_tile_id : i8
+    ```
+
+    Example:
+    ```mlir
+    // Allocate and return two 16-bit element "virtual tile" ids
+    %za0_h = arm_sme.get_tile_id : i16
+    %za1_h = arm_sme.get_tile_id : i16
+    ```
+
+    Example:
+    ```mlir
+    // Allocate and return a 128-bit element "virtual tile" id
+    %za0_q = arm_sme.get_tile_id : i128
+    ```
+  }];
+
+  let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let assemblyFormat = "attr-dict `:` type($tile_id)";
+}
+
//===----------------------------------------------------------------------===//
// ArmSME Intrinsic op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index ad10fdb4255c04..28a1716228e8f0 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -308,6 +308,12 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
::llvm::cast<VectorType>($_self).isScalable()}]>;
+// True iff the type is a VectorType in which every dimension is scalable,
+// e.g. vector<[4]x[4]xf32> but not vector<4x[4]xf32>.
+def allDimsScalableVectorTypePred : And<[IsVectorTypePred,
+    CPred<"::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()">]>;
+
// Whether a type is a TensorType.
def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">;
@@ -488,6 +494,7 @@ def I8 : I<8>;
def I16 : I<16>;
def I32 : I<32>;
def I64 : I<64>;
+// 128-bit signless integer type.
+def I128 : I<128>;
// Any signed integer type irrespective of its width.
def AnySignedInteger : Type<
@@ -745,6 +752,10 @@ class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
== }]
# allowedlength>)>]>;
+// Whether the shape of a vector matches the given `shape` list.
+// NOTE: `$_self` is cast to VectorType unconditionally, so this predicate must
+// only be evaluated after another predicate (e.g. earlier in an And<>) has
+// established that `$_self` is a vector type.
+class IsVectorOfShape<list<int> shape>
+  : CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == "
+          "::llvm::ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;
+
// Any vector where the number of elements is from the given
// `allowedLengths` list
class VectorOfLength<list<int> allowedLengths> : Type<
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 7f5aa61aa327eb..750627421215df 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -34,3 +34,23 @@ void ArmSMEDialect::initialize() {
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
>();
}
+
+// cast_vector_to_tile(cast_tile_to_vector(tile_id)) -> tile_id
+//
+// The replacement goes through the PatternRewriter, rather than calling
+// replaceAllUsesWith on the op directly. This notifies the pattern driver of
+// the IR change and erases the now-dead cast.
+//
+// No type check is needed: TileElementWidthMatchesTileID ties both tile ids
+// to the element bitwidth of the same intermediate vector, so their types are
+// always equal.
+LogicalResult CastVectorToTile::canonicalize(CastVectorToTile op,
+                                             PatternRewriter &rewriter) {
+  auto castTileToVectorOp = op.getVector().getDefiningOp<CastTileToVector>();
+  if (!castTileToVectorOp)
+    return failure();
+  rewriter.replaceOp(op, castTileToVectorOp.getTileId());
+  return success();
+}
+
+// cast_tile_to_vector(cast_vector_to_tile(tile)) -> tile
+//
+// A tile id only encodes the element bitwidth. A vector<[8]x[8]xf16> cast to
+// an i16 tile id may therefore legally be cast back to vector<[8]x[8]xi16>.
+// Folding such a pair would replace a value with one of a different type, so
+// the fold only fires when the source vector has exactly this op's result
+// type. The replacement goes through the PatternRewriter so the pattern
+// driver is notified and the dead cast is erased.
+LogicalResult CastTileToVector::canonicalize(CastTileToVector op,
+                                             PatternRewriter &rewriter) {
+  auto castVectorToTileOp = op.getTileId().getDefiningOp<CastVectorToTile>();
+  if (!castVectorToTileOp)
+    return failure();
+  auto sourceVector = castVectorToTileOp.getVector();
+  if (sourceVector.getType() != op.getType())
+    return failure();
+  rewriter.replaceOp(op, sourceVector);
+  return success();
+}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 2eb061da49f440..aa59aa5b2b3585 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -7,9 +7,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
using namespace mlir;
@@ -43,6 +45,17 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
return success();
}
};
+
+/// Lowers `arm_sme.get_tile_id` to an LLVM integer constant.
+///
+/// The constant is created with the op's own result type (i8/i16/i32/i64/i128)
+/// rather than a fixed i32. Existing users of the tile id, such as the cast
+/// ops whose width constraint was checked against that type, therefore stay
+/// well-typed after the replacement.
+struct GetTileIDConversion : public ConvertOpToLLVMPattern<GetTileID> {
+  using ConvertOpToLLVMPattern<GetTileID>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(GetTileID op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO: implement tile allocation, currently only tile 0 is supported.
+    rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, op.getType(), 0);
+    return success();
+  }
+};
} // namespace
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
@@ -52,9 +65,11 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
void mlir::configureArmSMELegalizeForExportTarget(
LLVMConversionTarget &target) {
- target.addLegalOp<scf::ForOp, scf::YieldOp, arm_sme::aarch64_sme_zero,
+ target.addLegalOp<scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
+ arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_za_enable,
arm_sme::aarch64_sme_za_disable>();
+ target.addLegalOp<GetTileID>();
// Mark 'func.func' ops as legal if either:
// 1. no 'arm_za' function attribute is present.
diff --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
new file mode 100644
index 00000000000000..06bbd3050fdece
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @cast_vector_to_tile__cast_tile_to_vector
+// CHECK-SAME: %[[TILE_ID:.*]]: i8
+func.func @cast_vector_to_tile__cast_tile_to_vector(%in_id : i8) -> i8 {
+  // CHECK-NOT: arm_sme.cast_tile_to_vector
+  // CHECK-NOT: arm_sme.cast_vector_to_tile
+  // CHECK-NEXT: return %[[TILE_ID]] : i8
+  %vec = arm_sme.cast_tile_to_vector %in_id : i8 to vector<[16]x[16]xi8>
+  %out_id = arm_sme.cast_vector_to_tile %vec : vector<[16]x[16]xi8> to i8
+  return %out_id : i8
+}
+
+// -----
+
+// CHECK-LABEL: @cast_tile_to_vector__cast_vector_to_tile
+// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>
+func.func @cast_tile_to_vector__cast_vector_to_tile(%in_vec : vector<[16]x[16]xi8>) -> vector<[16]x[16]xi8> {
+  // CHECK-NOT: arm_sme.cast_vector_to_tile
+  // CHECK-NOT: arm_sme.cast_tile_to_vector
+  // CHECK-NEXT: return %[[TILE]] : vector<[16]x[16]xi8>
+  %id = arm_sme.cast_vector_to_tile %in_vec : vector<[16]x[16]xi8> to i8
+  %out_vec = arm_sme.cast_tile_to_vector %id : i8 to vector<[16]x[16]xi8>
+  return %out_vec : vector<[16]x[16]xi8>
+}
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
new file mode 100644
index 00000000000000..1609ed39e64164
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
+  // expected-error@+1 {{op failed to verify that `tile_id` has the same number of bits as elements in `vector`}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[8]x[8]xi16>
+  return %0 : vector<[8]x[8]xi16>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) -> vector<[16]xi8> {
+  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8>
+  return %0 : vector<[16]xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vector<[16]x[16]xi4> {
+  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x[16]xi4>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
+  return %0 : vector<[16]x[16]xi4>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile_id : i8) -> vector<16x[16]xi8> {
+  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<16x[16]xi8>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
+  return %0 : vector<16x[16]xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile_id : i8) -> vector<[16]x16xi8> {
+  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x16xi8>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x16xi8>
+  return %0 : vector<[16]x16xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
+  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[4]x[16]xi8>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
+  return %0 : vector<[4]x[16]xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
+  // expected-error@+1 {{op failed to verify that `tile_id` has the same number of bits as elements in `vector`}}
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[1]x[1]xi128> to i32
+  return %0 : i32
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 {
+  // expected-error@+1 {{op operand #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8
+  return %0 : i8
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id__bad_type() -> i1 {
+  // expected-error@+1 {{op result #0 must be 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer}}
+  %0 = arm_sme.get_tile_id : i1
+  return %0 : i1
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
new file mode 100644
index 00000000000000..6256b5bc062d31
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -0,0 +1,185 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i8(%id : i8) -> vector<[16]x[16]xi8> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i8 to vector<[16]x[16]xi8>
+  %vec = arm_sme.cast_tile_to_vector %id : i8 to vector<[16]x[16]xi8>
+  return %vec : vector<[16]x[16]xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i16(%id : i16) -> vector<[8]x[8]xi16> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xi16>
+  %vec = arm_sme.cast_tile_to_vector %id : i16 to vector<[8]x[8]xi16>
+  return %vec : vector<[8]x[8]xi16>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i32(%id : i32) -> vector<[4]x[4]xi32> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i32 to vector<[4]x[4]xi32>
+  %vec = arm_sme.cast_tile_to_vector %id : i32 to vector<[4]x[4]xi32>
+  return %vec : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i64(%id : i64) -> vector<[2]x[2]xi64> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i64 to vector<[2]x[2]xi64>
+  %vec = arm_sme.cast_tile_to_vector %id : i64 to vector<[2]x[2]xi64>
+  return %vec : vector<[2]x[2]xi64>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i128(%id : i128) -> vector<[1]x[1]xi128> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i128 to vector<[1]x[1]xi128>
+  %vec = arm_sme.cast_tile_to_vector %id : i128 to vector<[1]x[1]xi128>
+  return %vec : vector<[1]x[1]xi128>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_f16(%id : i16) -> vector<[8]x[8]xf16> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xf16>
+  %vec = arm_sme.cast_tile_to_vector %id : i16 to vector<[8]x[8]xf16>
+  return %vec : vector<[8]x[8]xf16>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_bf16(%id : i16) -> vector<[8]x[8]xbf16> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xbf16>
+  %vec = arm_sme.cast_tile_to_vector %id : i16 to vector<[8]x[8]xbf16>
+  return %vec : vector<[8]x[8]xbf16>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_f32(%id : i32) -> vector<[4]x[4]xf32> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i32 to vector<[4]x[4]xf32>
+  %vec = arm_sme.cast_tile_to_vector %id : i32 to vector<[4]x[4]xf32>
+  return %vec : vector<[4]x[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_f64(%id : i64) -> vector<[2]x[2]xf64> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i64 to vector<[2]x[2]xf64>
+  %vec = arm_sme.cast_tile_to_vector %id : i64 to vector<[2]x[2]xf64>
+  return %vec : vector<[2]x[2]xf64>
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i8(%vec : vector<[16]x[16]xi8>) -> i8 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[16]x[16]xi8> to i8
+  %id = arm_sme.cast_vector_to_tile %vec : vector<[16]x[16]xi8> to i8
+  return %id : i8
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i16(%vec : vector<[8]x[8]xi16>) -> i16 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xi16> to i16
+  %id = arm_sme.cast_vector_to_tile %vec : vector<[8]x[8]xi16> to i16
+  return %id : i16
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i32(%vec : vector<[4]x[4]xi32>) -> i32 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[4]x[4]xi32> to i32
+  %id = arm_sme.cast_vector_to_tile %vec : vector<[4]x[4]xi32> to i32
+  return %id : i32
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i64(%vec : vector<[2]x[2]xi64>) -> i64 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[2]x[2]xi64> to i64
+  %id = arm_sme.cast_vector_to_tile %vec : vector<[2]x[2]xi64> to i64
+  return %id : i64
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i128(%vec : vector<[1]x[1]xi128>) -> i128 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[1]x[1]xi128> to i128
+  %id = arm_sme.cast_vector_to_tile %vec : vector<[1]x[1]xi128> to i128
+  return %id : i128
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_f16(%vec : vector<[8]x[8]xf16>) -> i16 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xf16> to i16
+  %id = arm_sme.cast_vector_to_tile %vec : vector<[8]x[8]xf16> to i16
+  return %id : i16
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_bf16(%vec : vector<[8]x[8]xbf16>) -> i16 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xbf16> to i16
+  %id = arm_sme.cast_vector_to_tile %vec : vector<[8]x[8]xbf16> to i16
+  return %id : i16
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_f32(%vec : vector<[4]x[4]xf32>) -> i32 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[4]x[4]xf32> to i32
+  %id = arm_sme.cast_vector_to_tile %vec : vector<[4]x[4]xf32> to i32
+  return %id : i32
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_f64(%vec : vector<[2]x[2]xf64>) -> i64 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[2]x[2]xf64> to i64
+  %id = arm_sme.cast_vector_to_tile %vec : vector<[2]x[2]xf64> to i64
+  return %id : i64
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i8() -> i8 {
+  // CHECK: arm_sme.get_tile_id : i8
+  %id = arm_sme.get_tile_id : i8
+  return %id : i8
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i16() -> i16 {
+  // CHECK: arm_sme.get_tile_id : i16
+  %id = arm_sme.get_tile_id : i16
+  return %id : i16
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i32() -> i32 {
+  // CHECK: arm_sme.get_tile_id : i32
+  %id = arm_sme.get_tile_id : i32
+  return %id : i32
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i64() -> i64 {
+  // CHECK: arm_sme.get_tile_id : i64
+  %id = arm_sme.get_tile_id : i64
+  return %id : i64
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i128() -> i128 {
+  // CHECK: arm_sme.get_tile_id : i128
+  %id = arm_sme.get_tile_id : i128
+  return %id : i128
+}
More information about the Mlir-commits
mailing list