[Mlir-commits] [mlir] 6ff9761 - [mlir][ArmSME] Add custom get_tile_id and cast ops

Cullen Rhodes llvmlistbot at llvm.org
Tue Jul 18 00:42:08 PDT 2023


Author: Cullen Rhodes
Date: 2023-07-18T07:41:45Z
New Revision: 6ff9761a69df24d33f7c1047a956911bf28a754b

URL: https://github.com/llvm/llvm-project/commit/6ff9761a69df24d33f7c1047a956911bf28a754b
DIFF: https://github.com/llvm/llvm-project/commit/6ff9761a69df24d33f7c1047a956911bf28a754b.diff

LOG: [mlir][ArmSME] Add custom get_tile_id and cast ops

This patch adds three new custom ops to the ArmSME dialect:

  * arm_sme.get_tile_id - returns a scalar integer representing an SME
    "virtual tile" that is not in use.
  * arm_sme.cast_tile_to_vector - casts from a tile id to a 2-d scalable
    vector type, which represents an SME "virtual tile".
  * arm_sme.cast_vector_to_tile - casts from a 2-d scalable vector type,
    which represents an SME "virtual tile", to a tile id.

The 'arm_sme.get_tile_id' op currently only supports tile 0; a follow-up
patch will implement proper tile allocation. A further follow-up patch
will demonstrate load/store to/from ZA using these ops.

See the op descriptions for further details and examples.

Thanks to @paulwalker-arm and @awarzynski for helping drive this.

Reviewed By: awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D154941

Added: 
    mlir/test/Dialect/ArmSME/canonicalize.mlir
    mlir/test/Dialect/ArmSME/invalid.mlir
    mlir/test/Dialect/ArmSME/roundtrip.mlir

Modified: 
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
    mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 140ed51b101b97..19283155b21714 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -36,6 +36,166 @@ def ArmSME_Dialect : Dialect {
   let dependentDialects = ["scf::SCFDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ArmSME type definitions
+//===----------------------------------------------------------------------===//
+
+class SMETileType<Type datatype, list<int> dims, string description>
+  : ShapedContainerType<[datatype],
+      And<[IsVectorOfRankPred<[2]>, allDimsScalableVectorTypePred,
+           IsVectorOfShape<dims>]>,
+  description>;
+
+def nxnxv16i8  : SMETileType<I8,   [16, 16], "vector<[16]x[16]xi8>">;
+def nxnxv8i16  : SMETileType<I16,  [8,  8 ], "vector<[8]x[8]xi16>">;
+def nxnxv4i32  : SMETileType<I32,  [4,  4 ], "vector<[4]x[4]xi32>">;
+def nxnxv2i64  : SMETileType<I64,  [2,  2 ], "vector<[2]x[2]xi64>">;
+def nxnxv1i128 : SMETileType<I128, [1,  1 ], "vector<[1]x[1]xi128>">;
+
+def nxnxv8f16  : SMETileType<F16,  [8,  8 ], "vector<[8]x[8]xf16>">;
+def nxnxv8bf16 : SMETileType<BF16, [8,  8 ], "vector<[8]x[8]xbf16>">;
+def nxnxv4f32  : SMETileType<F32,  [4,  4 ], "vector<[4]x[4]xf32>">;
+def nxnxv2f64  : SMETileType<F64,  [2,  2 ], "vector<[2]x[2]xf64>">;
+
+def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
+                         nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
+
+// A type constraint that verifies the bitwidth of the scalar integer returned
+// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
+def TileElementWidthMatchesTileID : TypesMatchWith<
+  "`tile_id` has the same number of bits as elements in `vector`",
+  "vector", "tile_id",
+  "IntegerType::get("
+      "$_self.getContext(),"
+      "::llvm::isa<IntegerType>(::llvm::cast<VectorType>($_self).getElementType())"
+          "? ::llvm::cast<IntegerType>("
+                  "::llvm::cast<VectorType>($_self).getElementType())"
+                  ".getWidth()"
+          ": ::llvm::cast<FloatType>("
+                  "::llvm::cast<VectorType>($_self).getElementType())"
+                  ".getWidth())">;
+
+//===----------------------------------------------------------------------===//
+// ArmSME op definitions
+//===----------------------------------------------------------------------===//
+
+class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
+  Op<ArmSME_Dialect, mnemonic, traits> {}
+
+def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthMatchesTileID]> {
+  let summary = "Cast from tile id to 2-d scalable vector type";
+  let description = [{
+    A `cast_tile_to_vector` operation does a cast from a tile id to a 2-d
+    scalable vector type, which represents an SME "virtual tile". This would
+    normally be used when lowering operations that return "virtual tile" vector
+    types to model the output. This is required to preserve dataflow as SME
+    intrinsics have no return values.
+
+    Example:
+
+    Input:
+    ```mlir
+    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    ```
+
+    After lowering `vector.load`:
+    ```mlir
+    %tile_id = arm_sme.get_tile_id : i32
+    scf.for %vnum = %c0 to %num_vectors step %c1 {
+      // ...
+      "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+    }
+    %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    ```
+
+    In the example above, the `vector.load` can't be replaced with an SME
+    intrinsic that has no outputs since it is used by the `vector.store`.
+    However, by inserting a `cast_tile_to_vector` op after the load intrinsics
+    the `vector.load` can be replaced. This enables "local" rewrites on
+    individual vector ops, rather than "global" rewrites that would have to
+    look at the vector op uses and also lower them.
+
+    Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
+    the cast away if it comes from an `arm_sme.cast_vector_to_tile`.
+  }];
+  let arguments = (ins AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let results = (outs SMETile:$vector);
+  let assemblyFormat =
+    "$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
+  let hasCanonicalizeMethod = 1;
+}
+
+def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthMatchesTileID]> {
+  let summary = "Cast from 2-d scalable vector type to tile id";
+  let description = [{
+    A `cast_vector_to_tile` operation does a cast from a 2-d scalable vector
+    type, which represents an SME "virtual tile", to a tile id. This is
+    required to preserve dataflow as the SME intrinsics have no return values.
+
+    Example:
+
+    Input:
+    ```mlir
+    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    ```
+
+    After lowering `vector.store`:
+    ```mlir
+    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    scf.for %vnum = %c0 to %num_vectors step %c1 {
+      // ...
+      %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
+      "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+    }
+    ```
+
+    Canonicalization will look through `arm_sme.cast_vector_to_tile` and fold
+    the cast away if it comes from an `arm_sme.cast_tile_to_vector`.
+  }];
+  let arguments = (ins SMETile:$vector);
+  let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let assemblyFormat =
+    "$vector attr-dict `:` type($vector) `to` type($tile_id)";
+  let hasCanonicalizeMethod = 1;
+}
+
+def GetTileID : ArmSME_Op<"get_tile_id", [Pure]> {
+  let summary = "Returns an SME \"virtual tile\" id";
+  let description = [{
+    A `get_tile_id` operation returns a scalar integer representing an SME
+    "virtual tile" id. The bitwidth of the scalar indicates the element
+    bitwidth of the "virtual tile".
+
+    The scope of a tile id is a function; tile ids cannot be passed to or
+    returned from functions.
+
+    Example:
+    ```mlir
+    // Allocate and return an 8-bit element "virtual tile" id
+    %za0_b = arm_sme.get_tile_id : i8
+    ```
+
+    Example:
+    ```mlir
+    // Allocate and return two 16-bit element "virtual tile" ids
+    %za0_h = arm_sme.get_tile_id : i16
+    %za1_h = arm_sme.get_tile_id : i16
+    ```
+
+    Example:
+    ```mlir
+    // Allocate and return a 128-bit element "virtual tile" id
+    %za0_q = arm_sme.get_tile_id : i128
+    ```
+  }];
+
+  let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let assemblyFormat = "attr-dict `:` type($tile_id)";
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME Intrinsic op definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index ad10fdb4255c04..28a1716228e8f0 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -308,6 +308,12 @@ def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
 def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
                                    ::llvm::cast<VectorType>($_self).isScalable()}]>;
 
+// Whether a type is a VectorType and all dimensions are scalable.
+def allDimsScalableVectorTypePred : And<[
+  IsVectorTypePred,
+  CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>
+]>;
+
 // Whether a type is a TensorType.
 def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">;
 
@@ -488,6 +494,7 @@ def I8  : I<8>;
 def I16 : I<16>;
 def I32 : I<32>;
 def I64 : I<64>;
+def I128 : I<128>;
 
 // Any signed integer type irrespective of its width.
 def AnySignedInteger : Type<
@@ -745,6 +752,10 @@ class IsScalableVectorOfLengthPred<list<int> allowedLengths> :
                            == }]
                          # allowedlength>)>]>;
 
+// Whether a vector's shape equals `shape`; only valid on a known VectorType.
+class IsVectorOfShape<list<int> shape>
+  : CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ::llvm::ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;
+
 // Any vector where the number of elements is from the given
 // `allowedLengths` list
 class VectorOfLength<list<int> allowedLengths> : Type<

diff  --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 7f5aa61aa327eb..750627421215df 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -34,3 +34,23 @@ void ArmSMEDialect::initialize() {
 #include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
       >();
 }
+
+// Folds cast_vector_to_tile(cast_tile_to_vector(tile_id)) -> tile_id.
+LogicalResult CastVectorToTile::canonicalize(CastVectorToTile op,
+                                             PatternRewriter &rewriter) {
+  auto producer = op.getVector().getDefiningOp<CastTileToVector>();
+  if (!producer)
+    return failure();
+  rewriter.replaceOp(op, producer.getTileId()); // Notifies the pattern driver.
+  return success();
+}
+
+// Folds cast_tile_to_vector(cast_vector_to_tile(tile)) -> tile.
+LogicalResult CastTileToVector::canonicalize(CastTileToVector op,
+                                             PatternRewriter &rewriter) {
+  auto producer = op.getTileId().getDefiningOp<CastVectorToTile>();
+  if (!producer)
+    return failure();
+  rewriter.replaceOp(op, producer.getVector()); // Notifies the pattern driver.
+  return success();
+}

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 2eb061da49f440..aa59aa5b2b3585 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -7,9 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 
 using namespace mlir;
@@ -43,6 +45,17 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
     return success();
   }
 };
+
+struct GetTileIDConversion : public ConvertOpToLLVMPattern<GetTileID> {
+  using ConvertOpToLLVMPattern<GetTileID>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(GetTileID op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO: implement tile allocation; for now tile 0, in the op's own type.
+    rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, op.getType(), 0);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateArmSMELegalizeForLLVMExportPatterns(
@@ -52,9 +65,11 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
 
 void mlir::configureArmSMELegalizeForExportTarget(
     LLVMConversionTarget &target) {
-  target.addLegalOp<scf::ForOp, scf::YieldOp, arm_sme::aarch64_sme_zero,
+  target.addLegalOp<scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
+                    arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
                     arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_za_enable,
                     arm_sme::aarch64_sme_za_disable>();
+  target.addLegalOp<GetTileID>();
 
   // Mark 'func.func' ops as legal if either:
   //   1. no 'arm_za' function attribute is present.

diff  --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
new file mode 100644
index 00000000000000..06bbd3050fdece
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @cast_vector_to_tile__cast_tile_to_vector
+// CHECK-SAME: %[[TILE_ID:.*]]: i8
+func.func @cast_vector_to_tile__cast_tile_to_vector(%tile_id_0 : i8) -> i8 {
+  // CHECK-NOT: arm_sme.cast_tile_to_vector
+  // CHECK-NOT: arm_sme.cast_vector_to_tile
+  // CHECK-NEXT: return %[[TILE_ID]] : i8
+  %tile = arm_sme.cast_tile_to_vector %tile_id_0 : i8 to vector<[16]x[16]xi8>
+  %tile_id_1 = arm_sme.cast_vector_to_tile %tile : vector<[16]x[16]xi8> to i8
+  return %tile_id_1 : i8
+}
+
+// -----
+
+// CHECK-LABEL: @cast_tile_to_vector__cast_vector_to_tile
+// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>
+func.func @cast_tile_to_vector__cast_vector_to_tile(%tile_0 : vector<[16]x[16]xi8>) -> vector<[16]x[16]xi8> {
+  // CHECK-NOT: arm_sme.cast_vector_to_tile
+  // CHECK-NOT: arm_sme.cast_tile_to_vector
+  // CHECK-NEXT: return %[[TILE]] : vector<[16]x[16]xi8>
+  %tile_id = arm_sme.cast_vector_to_tile %tile_0 : vector<[16]x[16]xi8> to i8
+  %tile_1 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
+  return %tile_1 : vector<[16]x[16]xi8>
+}

diff  --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
new file mode 100644
index 00000000000000..1609ed39e64164
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
+  // expected-error at +1 {{op failed to verify that `tile_id` has the same number of bits as elements in `vector`}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[8]x[8]xi16>
+  return %0 : vector<[8]x[8]xi16>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) -> vector<[16]xi8> {
+  // expected-error at +1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8>
+  return %0 : vector<[16]xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vector<[16]x[16]xi4> {
+  // expected-error at +1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x[16]xi4>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
+  return %0 : vector<[16]x[16]xi4>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile_id : i8) -> vector<16x[16]xi8> {
+  // expected-error at +1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<16x[16]xi8>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
+  return %0 : vector<16x[16]xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile_id : i8) -> vector<[16]x16xi8> {
+  // expected-error at +1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x16xi8>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x16xi8>
+  return %0 : vector<[16]x16xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
+  // expected-error at +1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[4]x[16]xi8>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
+  return %0 : vector<[4]x[16]xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
+  // expected-error at +1 {{op failed to verify that `tile_id` has the same number of bits as elements in `vector`}}
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[1]x[1]xi128> to i32
+  return %0 : i32
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 {
+  // expected-error at +1 {{op operand #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8
+  return %0 : i8
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id__bad_type() -> i1 {
+  // expected-error at +1 {{op result #0 must be 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer}}
+  %0 = arm_sme.get_tile_id : i1
+  return %0 : i1
+}

diff  --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
new file mode 100644
index 00000000000000..6256b5bc062d31
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -0,0 +1,185 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i8(%tile_id : i8) -> vector<[16]x[16]xi8> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i8 to vector<[16]x[16]xi8>
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
+  return %0 : vector<[16]x[16]xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i16(%tile_id : i16) -> vector<[8]x[8]xi16> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xi16>
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xi16>
+  return %0 : vector<[8]x[8]xi16>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i32(%tile_id : i32) -> vector<[4]x[4]xi32> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i32 to vector<[4]x[4]xi32>
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+  return %0 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i64(%tile_id : i64) -> vector<[2]x[2]xi64> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i64 to vector<[2]x[2]xi64>
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i64 to vector<[2]x[2]xi64>
+  return %0 : vector<[2]x[2]xi64>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i128(%tile_id : i128) -> vector<[1]x[1]xi128> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i128 to vector<[1]x[1]xi128>
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i128 to vector<[1]x[1]xi128>
+  return %0 : vector<[1]x[1]xi128>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_f16(%tile_id : i16) -> vector<[8]x[8]xf16> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xf16>
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xf16>
+  return %0 : vector<[8]x[8]xf16>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_bf16(%tile_id : i16) -> vector<[8]x[8]xbf16> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xbf16>
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xbf16>
+  return %0 : vector<[8]x[8]xbf16>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_f32(%tile_id : i32) -> vector<[4]x[4]xf32> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i32 to vector<[4]x[4]xf32>
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xf32>
+  return %0 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_f64(%tile_id : i64) -> vector<[2]x[2]xf64> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i64 to vector<[2]x[2]xf64>
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i64 to vector<[2]x[2]xf64>
+  return %0 : vector<[2]x[2]xf64>
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i8(%vector : vector<[16]x[16]xi8>) -> i8 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[16]x[16]xi8> to i8
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]x[16]xi8> to i8
+  return %0 : i8
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i16(%vector : vector<[8]x[8]xi16>) -> i16 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xi16> to i16
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xi16> to i16
+  return %0 : i16
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i32(%vector : vector<[4]x[4]xi32>) -> i32 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[4]x[4]xi32> to i32
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[4]x[4]xi32> to i32
+  return %0 : i32
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i64(%vector : vector<[2]x[2]xi64>) -> i64 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[2]x[2]xi64> to i64
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[2]x[2]xi64> to i64
+  return %0 : i64
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i128(%vector : vector<[1]x[1]xi128>) -> i128 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[1]x[1]xi128> to i128
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[1]x[1]xi128> to i128
+  return %0 : i128
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_f16(%vector : vector<[8]x[8]xf16>) -> i16 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xf16> to i16
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xf16> to i16
+  return %0 : i16
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_bf16(%vector : vector<[8]x[8]xbf16>) -> i16 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xbf16> to i16
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xbf16> to i16
+  return %0 : i16
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_f32(%vector : vector<[4]x[4]xf32>) -> i32 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[4]x[4]xf32> to i32
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[4]x[4]xf32> to i32
+  return %0 : i32
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_f64(%vector : vector<[2]x[2]xf64>) -> i64 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[2]x[2]xf64> to i64
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[2]x[2]xf64> to i64
+  return %0 : i64
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i8() -> i8 {
+  // CHECK: arm_sme.get_tile_id : i8
+  %0 = arm_sme.get_tile_id : i8
+  return %0 : i8
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i16() -> i16 {
+  // CHECK: arm_sme.get_tile_id : i16
+  %0 = arm_sme.get_tile_id : i16
+  return %0 : i16
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i32() -> i32 {
+  // CHECK: arm_sme.get_tile_id : i32
+  %0 = arm_sme.get_tile_id : i32
+  return %0 : i32
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i64() -> i64 {
+  // CHECK: arm_sme.get_tile_id : i64
+  %0 = arm_sme.get_tile_id : i64
+  return %0 : i64
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i128() -> i128 {
+  // CHECK: arm_sme.get_tile_id : i128
+  %0 = arm_sme.get_tile_id : i128
+  return %0 : i128
+}


        


More information about the Mlir-commits mailing list