[Mlir-commits] [mlir] [mlir][ArmSME] Add optional mask operand to tile_store (PR #70657)

Cullen Rhodes llvmlistbot at llvm.org
Tue Oct 31 07:32:15 PDT 2023


https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/70657

>From 27b129f8a55ccc9583743c8fc6de0b977c62588f Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 13:16:54 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Add optional mask operand to tile_store

---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 44 ++++++++++++++-----
 .../VectorToArmSME/VectorToArmSME.cpp         |  4 +-
 mlir/test/Dialect/ArmSME/invalid.mlir         | 14 ++++++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       |  9 ++++
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 14 ++++++
 5 files changed, 71 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 37a2257a0015ce7..51ee1af83a9192e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -60,6 +60,12 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
                   "::llvm::cast<VectorType>($_self).getElementType())"
                   ".getWidth())">;
 
+class HasMatchingMaskTypeConstraint<string vector, string mask> :
+  OptionalTypesMatchWith<
+    mask # " has i1 element type and same shape as " # vector,
+    vector, mask,
+    "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
+
 //===----------------------------------------------------------------------===//
 // ArmSME attr definitions
 //===----------------------------------------------------------------------===//
@@ -259,14 +265,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
     "result", "padding",
     "::llvm::cast<VectorType>($_self).getElementType()"
   >,
-  OptionalTypesMatchWith<
-    "mask has i1 element type and same shape as result",
-    "result", "mask",
-    "VectorType("
-      "VectorType::Builder("
-        "::llvm::cast<mlir::VectorType>($_self)"
-      ").setElementType(IntegerType::get($_self.getContext(), 1)))"
-  >,
+  HasMatchingMaskTypeConstraint<"result", "mask">,
   PredOpTrait<
     "both `padding` and `mask` should be provided or neither",
     CPred<"bool(getPadding()) == bool(getMask())">
@@ -345,7 +344,10 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
       "attr-dict `:` type($base) `,` type($result)";
 }
 
-def TileStoreOp : ArmSME_Op<"tile_store"> {
+def TileStoreOp : ArmSME_Op<"tile_store", [
+  AttrSizedOperandSegments,
+  HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
+]> {
   let summary = "Tile store operation";
   let description = [{
     Stores a 2D SME "virtual tile" to memory defined by a base and indices,
@@ -356,6 +358,11 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
     rank 2 with dynamic dimensions, since the operation is scalable, and the
     element type must be a scalar that matches the element type of the result.
 
+    An optional SSA value `mask` may be specified to mask out elements written
+    to the MemRef. The `mask` type is an `i1` vector of the same shape as the
+    vector type that matches how elements are written into the MemRef. Elements
+    whose corresponding mask element is `0` are masked out.
+
     Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
     ```mlir
     arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
@@ -370,10 +377,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
     ```mlir
     arm_sme.tile_store %tile, %base[%c0, %c0] layout<horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
     ```
+
+    Example 4: Masked store of a 32-bit float element ZA tile with vertical layout to memory.
+    ```mlir
+    arm_sme.tile_store %tile, %base[%c0, %c0], %mask layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+    ```
   }];
   let arguments = (ins SMETile:$valueToStore,
     Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
-    Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
+    Variadic<Index>:$indices, Optional<AnyVector>:$mask,
+    ArmSME_TileSliceLayoutAttr:$layout
   );
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
@@ -384,9 +397,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
     }
   }];
 
+  let builders = [
+    OpBuilder<(ins "Value":$valueToStore, "Value":$base,
+                   "ValueRange":$indices), [{
+      build($_builder, $_state, valueToStore, base, indices, {});
+    }]>,
+  ];
+
   let assemblyFormat =
-    "$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
-      "`:` type($base) `,` type($valueToStore)";
+    "$valueToStore `,` $base `[` $indices `]` (`,` $mask^)? (`layout` `` $layout^)?"
+      "attr-dict `:` type($base) `,` type($valueToStore)";
 }
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index b60c21e2ced7a8f..005dd546bf1632b 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -144,8 +144,8 @@ struct TransferWriteToArmSMELowering
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
-        writeOp, writeOp.getVector(), writeOp.getSource(),
-        writeOp.getIndices());
+        writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
+        writeOp.getMask());
     return success();
   }
 };
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 1d6386bbf3828fa..588b8e891faddef 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -164,6 +164,20 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
   return
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_store
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask : vector<[1]x[1]xi1>, %dest : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  // expected-note at -2 {{prior use here}}
+  // expected-error at +1 {{use of value '%mask' expects different type than prior uses: 'vector<[16]x[16]xi1>' vs 'vector<[1]x[1]xi1>}}
+  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.outerproduct
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 6d0aa48015c145a..8206bd80e3eb4cf 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -624,6 +624,15 @@ func.func @arm_sme_tile_store_ver_f64(%tile : vector<[2]x[2]xf64>, %dest : memre
 
 // -----
 
+func.func @arm_sme_tile_store_with_mask_ver_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %mask : vector<[4]x[4]xi1>) {
+  // CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
 /// Layout is optional and horizontal is the default, verify it's still parsed.
 func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
   // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 9eb7cd143e5b5ea..67fc016a7f4a985 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -315,6 +315,20 @@ func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?
 
 // -----
 
+// CHECK-LABEL: func.func @transfer_write_2d_with_mask_f64(
+// CHECK-SAME:                                             %[[VECTOR:.*]]: vector<[2]x[2]xf64>,
+// CHECK-SAME:                                             %[[DEST:.*]]: memref<?x?xf64>,
+// CHECK-SAME:                                             %[[MASK:.*]]: vector<[2]x[2]xi1>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref<?x?xf64>
+  return
+}
+
+// -----
+
 // The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
 // lowering only occurs for vector types of correct rank, shape, element size
 // and number of scalable dims.

>From f1cf692ec77e57aeac5af1f090d0ef304320b5d2 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 31 Oct 2023 14:20:58 +0000
Subject: [PATCH 2/2] Rebase and de-duplicate HasMatchingMaskTypeConstraint

---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 51ee1af83a9192e..a6af292524dc5fc 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -615,12 +615,6 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 }
 
-class HasMatchingMaskTypeConstraint<string operand> :
-  OptionalTypesMatchWith<
-    "shape of `" # operand #  "Mask` matches `" # operand # "`",
-    operand, operand # "Mask",
-    "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
-
 class OuterProductResultTileTypeConstraint<string operand> :
   OptionalTypesMatchWith<operand # "type is derived from `lhs` and `rhs`",
     "lhs", operand,
@@ -635,8 +629,8 @@ def OuterProductOp :
   ArmSME_Op<"outerproduct", [Pure,
     AttrSizedOperandSegments,
     AllTypesMatch<["lhs", "rhs"]>,
-    HasMatchingMaskTypeConstraint<"lhs">,
-    HasMatchingMaskTypeConstraint<"rhs">,
+    HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
+    HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
     PredOpTrait<
       "both `lhsMask` and `rhsMask` should be provided or neither",
       CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,



More information about the Mlir-commits mailing list