[Mlir-commits] [mlir] a1b2ace - [mlir][ArmSME] Add optional padding and mask operands to tile_load (#69195)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 30 05:12:56 PDT 2023
Author: Cullen Rhodes
Date: 2023-10-30T12:12:52Z
New Revision: a1b2ace137385388bf9bd7ea4b6df3ff298900f6
URL: https://github.com/llvm/llvm-project/commit/a1b2ace137385388bf9bd7ea4b6df3ff298900f6
DIFF: https://github.com/llvm/llvm-project/commit/a1b2ace137385388bf9bd7ea4b6df3ff298900f6.diff
LOG: [mlir][ArmSME] Add optional padding and mask operands to tile_load (#69195)
Padding and mask are optional, but if one is specified both must be
specified. This is consistent with vector.transfer_read.
Added:
Modified:
mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
mlir/test/Dialect/ArmSME/invalid.mlir
mlir/test/Dialect/ArmSME/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 9b9dbff10ea2da6..b30d0fdb866bd23 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -231,7 +231,26 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
let assemblyFormat = "attr-dict `:` type($res)";
}
-def TileLoadOp : ArmSME_Op<"tile_load"> {
+def TileLoadOp : ArmSME_Op<"tile_load", [
+ AttrSizedOperandSegments,
+ OptionalTypesMatchWith<
+ "padding type matches element type of result",
+ "result", "padding",
+ "::llvm::cast<VectorType>($_self).getElementType()"
+ >,
+ OptionalTypesMatchWith<
+ "mask has i1 element type and same shape as result",
+ "result", "mask",
+ "VectorType("
+ "VectorType::Builder("
+ "::llvm::cast<mlir::VectorType>($_self)"
+ ").setElementType(IntegerType::get($_self.getContext(), 1)))"
+ >,
+ PredOpTrait<
+ "both `padding` and `mask` should be provided or neither",
+ CPred<"bool(getPadding()) == bool(getMask())">
+ >,
+]> {
let summary = "Tile load operation";
let description = [{
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
@@ -242,6 +261,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
dimensions, since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.
+ An optional SSA value `padding` of the same element type as the MemRef may be
+ provided to specify the value used in place of masked-out elements.
+
+ An optional SSA value `mask` may be specified to mask out elements read
+ from the MemRef. The `mask` type is an `i1` vector with the same shape as
+ the result tile. Elements whose corresponding mask element is `0` are
+ masked out and replaced with `padding`.
+
+ If either `padding` or `mask` is specified, both must be specified.
+
Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
@@ -256,10 +285,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
```
+
+ Example 4: Masked load of float 32-bit element ZA tile with horizontal layout (default) from memory.
+ ```mlir
+ %tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref<?x?xf32>, vector<[4]x[4]xf32>
+ ```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices,
+ Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
);
let results = (outs SMETile:$result);
@@ -273,9 +308,20 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
}
}];
+ let builders = [
+ OpBuilder<(ins "VectorType":$resultType, "Value":$base,
+ "ValueRange":$indices, "TileSliceLayout":$layout), [{
+ build($_builder, $_state, resultType, base, indices, {}, {}, layout);
+ }]>,
+ OpBuilder<(ins "VectorType":$resultType, "Value":$base,
+ "ValueRange":$indices), [{
+ build($_builder, $_state, resultType, base, indices, {}, {}, {});
+ }]>,
+ ];
+
let assemblyFormat =
- "$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
- "`:` type($base) `,` type($result)";
+ "$base `[` $indices `]` (`,` $padding `,` $mask^)? (`layout` `` $layout^)?"
+ "attr-dict `:` type($base) `,` type($result)";
}
def TileStoreOp : ArmSME_Op<"tile_store"> {
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 431009b1b9ede2f..25c62f78d843543 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -1,5 +1,9 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_tile_to_vector
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
@@ -48,6 +52,10 @@ func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[1
return %0 : vector<[4]x[16]xi8>
}
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_vector_to_tile
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
@@ -64,6 +72,10 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
return %0 : i8
}
+//===----------------------------------------------------------------------===//
+// arm_sme.get_tile_id
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_get_tile_id__bad_type() -> i1 {
@@ -72,6 +84,10 @@ func.func @arm_sme_get_tile_id__bad_type() -> i1 {
return %0 : i1
}
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_move_vector_to_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
@@ -90,6 +106,10 @@ func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vect
return %0 : vector<[4]x[4]xf32>
}
+//===----------------------------------------------------------------------===//
+// arm_sme.move_tile_slice_to_vector
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
@@ -97,3 +117,36 @@ func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]
%0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
return %0 : vector<[2]xf64>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_load
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_load__bad_padding_type(%src : memref<?x?xf64>, %pad : f32, %mask : vector<[2]x[2]xi1>) {
+ %c0 = arith.constant 0 : index
+ // expected-note@-2 {{prior use here}}
+ // expected-error@+1 {{use of value '%pad' expects different type than prior uses: 'f64' vs 'f32'}}
+ %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[4]x[4]xi1>) {
+ %c0 = arith.constant 0 : index
+ // expected-note@-2 {{prior use here}}
+ // expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[2]x[2]xi1>' vs 'vector<[4]x[4]xi1>'}}
+ %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{op failed to verify that both `padding` and `mask` should be provided or neither}}
+ %tile = arm_sme.tile_load %src[%c0, %c0], %pad, : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index e5ba81eff836027..6866137267dc66a 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -438,6 +438,16 @@ func.func @arm_sme_tile_load_ver_f64(%src : memref<?x?xf64>) {
// -----
+/// Padding and mask are optional; check that the form with both provided round-trips.
+func.func @arm_sme_tile_load_hor_pad_f64(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[2]x[2]xi1>) {
+ // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
/// Layout is optional and horizontal is the default, verify it's still parsed.
func.func @arm_sme_tile_load_explicit_hor(%src : memref<?x?xi8>) {
// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
More information about the Mlir-commits
mailing list