[Mlir-commits] [mlir] [mlir][ArmSME] Add optional padding and mask operands to tile_load (PR #69195)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 16 05:31:29 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Cullen Rhodes (c-rhodes)

<details>
<summary>Changes</summary>

Padding and mask are optional, but if one is specified both must be specified. This is consistent with vector.transfer_read.

---
Full diff: https://github.com/llvm/llvm-project/pull/69195.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+47-3) 
- (modified) mlir/test/Dialect/ArmSME/invalid.mlir (+44) 
- (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+10) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index dab54b63d8d22be..6f6b54aad0058e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -231,7 +231,24 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
-def TileLoadOp : ArmSME_Op<"tile_load"> {
+def TileLoadOp : ArmSME_Op<"tile_load", [
+  AttrSizedOperandSegments,
+  TypesMatchWith<
+    "padding type matches element type of result (if present)",
+    "result", "padding",
+    "::llvm::cast<VectorType>($_self).getElementType()",
+    "!getPadding() || std::equal_to<>()"
+  >,
+  TypesMatchWith<
+    "mask has i1 element type and same shape as result (if present)",
+    "result", "mask",
+    "VectorType("
+      "VectorType::Builder("
+        "::llvm::cast<mlir::VectorType>($_self)"
+      ").setElementType(IntegerType::get($_self.getContext(), 1)))",
+    "!getMask() || std::equal_to<>()"
+  >
+]> {
   let summary = "Tile load operation";
   let description = [{
     Loads a 2D SME "virtual tile" from memory defined by a base and indices,
@@ -242,6 +259,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     dimensions, since the operation is scalable, and the element type must be a
     scalar that matches the element type of the result.
 
+    An optional SSA value `padding` with the same element type as the MemRef
+    may be provided to specify the fallback value used for masked-out elements.
+
+    An optional SSA value `mask` may be specified to mask out elements read
+    from the MemRef. The `mask` type is an `i1` vector with a shape that
+    matches how elements are read from the MemRef. Elements whose corresponding
+    mask element is `0` are masked out and replaced with `padding`.
+
+    If either `padding` or `mask` is specified, both must be specified.
+
     Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
     ```mlir
     %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
@@ -256,10 +283,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     ```mlir
     %tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
+
+    Example 4: Masked load of a 32-bit float element ZA tile with horizontal layout (default) from memory.
+    ```mlir
+    %tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref<?x?xf32>, vector<[4]x[4]xf32>
+    ```
   }];
   let arguments = (ins
     Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
     Variadic<Index>:$indices,
+    Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
     ArmSME_TileSliceLayoutAttr:$layout
   );
   let results = (outs SMETile:$result);
@@ -273,9 +306,20 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     }
   }];
 
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType, "Value":$base,
+                   "ValueRange":$indices, "TileSliceLayout":$layout), [{
+      build($_builder, $_state, resultType, base, indices, {}, {}, layout);
+    }]>,
+    OpBuilder<(ins "VectorType":$resultType, "Value":$base,
+                   "ValueRange":$indices), [{
+      build($_builder, $_state, resultType, base, indices, {}, {}, {});
+    }]>,
+  ];
+
   let assemblyFormat =
-    "$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
-      "`:` type($base) `,` type($result)";
+    "$base `[` $indices `]` (`,` $padding `,` $mask^)? (`layout` `` $layout^)?"
+      "attr-dict `:` type($base) `,` type($result)";
 }
 
 def TileStoreOp : ArmSME_Op<"tile_store"> {
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 431009b1b9ede2f..9229f0415c076c3 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_tile_to_vector
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
@@ -48,6 +52,10 @@ func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[1
   return %0 : vector<[4]x[16]xi8>
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_vector_to_tile
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
@@ -64,6 +72,10 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
   return %0 : i8
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.get_tile_id
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_get_tile_id__bad_type() -> i1 {
@@ -72,6 +84,10 @@ func.func @arm_sme_get_tile_id__bad_type() -> i1 {
   return %0 : i1
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_move_vector_to_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
@@ -90,6 +106,10 @@ func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vect
   return %0 : vector<[4]x[4]xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.move_tile_slice_to_vector
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
@@ -97,3 +117,27 @@ func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]
   %0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
   return %0 : vector<[2]xf64>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_load
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_load__bad_padding_type(%src : memref<?x?xf64>, %pad : f32, %mask : vector<[2]x[2]xi1>) {
+  %c0 = arith.constant 0 : index
+  // expected-note@-2 {{prior use here}}
+  // expected-error@+1 {{use of value '%pad' expects different type than prior uses: 'f64' vs 'f32'}}
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[4]x[4]xi1>) {
+  %c0 = arith.constant 0 : index
+  // expected-note@-2 {{prior use here}}
+  // expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[2]x[2]xi1>' vs 'vector<[4]x[4]xi1>'}}
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 427154158e797fd..f6459f085843655 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -438,6 +438,16 @@ func.func @arm_sme_tile_load_ver_f64(%src : memref<?x?xf64>) {
 
 // -----
 
+/// Padding and mask are optional
+func.func @arm_sme_tile_load_hor_pad_f64(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[2]x[2]xi1>) {
+  // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
 /// Layout is optional and horizontal is the default, verify it's still parsed.
 func.func @arm_sme_tile_load_explicit_hor(%src : memref<?x?xi8>) {
   // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>

``````````

</details>


https://github.com/llvm/llvm-project/pull/69195


More information about the Mlir-commits mailing list