[Mlir-commits] [mlir] [mlir][ArmSME] Add optional mask operand to tile_store (PR #70657)
Cullen Rhodes
llvmlistbot at llvm.org
Mon Oct 30 06:09:31 PDT 2023
https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/70657
None
>From 1cfe5d2ac5b9316f2fa07ba0cdd67d5147fc3b8b Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 13:16:54 +0000
Subject: [PATCH] [mlir][ArmSME] Add optional mask operand to tile_store
---
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 44 ++++++++++++++-----
.../VectorToArmSME/VectorToArmSME.cpp | 4 +-
mlir/test/Dialect/ArmSME/invalid.mlir | 14 ++++++
mlir/test/Dialect/ArmSME/roundtrip.mlir | 9 ++++
.../Dialect/ArmSME/vector-ops-to-sme.mlir | 14 ++++++
5 files changed, 71 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index b30d0fdb866bd23..7e46a2ce4baf897 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -60,6 +60,12 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
"::llvm::cast<VectorType>($_self).getElementType())"
".getWidth())">;
+class HasMatchingMaskTypeConstraint<string vector, string mask> :
+ OptionalTypesMatchWith<
+ mask # " has i1 element type and same shape as " # vector,
+ vector, mask,
+ "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
+
//===----------------------------------------------------------------------===//
// ArmSME attr definitions
//===----------------------------------------------------------------------===//
@@ -238,14 +244,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
"result", "padding",
"::llvm::cast<VectorType>($_self).getElementType()"
>,
- OptionalTypesMatchWith<
- "mask has i1 element type and same shape as result",
- "result", "mask",
- "VectorType("
- "VectorType::Builder("
- "::llvm::cast<mlir::VectorType>($_self)"
- ").setElementType(IntegerType::get($_self.getContext(), 1)))"
- >,
+ HasMatchingMaskTypeConstraint<"result", "mask">,
PredOpTrait<
"both `padding` and `mask` should be provided or neither",
CPred<"bool(getPadding()) == bool(getMask())">
@@ -324,7 +323,10 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
"attr-dict `:` type($base) `,` type($result)";
}
-def TileStoreOp : ArmSME_Op<"tile_store"> {
+def TileStoreOp : ArmSME_Op<"tile_store", [
+ AttrSizedOperandSegments,
+ HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
+]> {
let summary = "Tile store operation";
let description = [{
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
@@ -335,6 +337,11 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
rank 2 with dynamic dimensions, since the operation is scalable, and the
element type must be a scalar that matches the element type of the result.
+ An optional SSA value `mask` may be specified to mask out elements written
+ to the MemRef. The `mask` type is an `i1` vector with the same shape as the
+ vector being stored (`valueToStore`). Elements whose corresponding mask
+ element is `0` are masked out, i.e. they are not written to the MemRef.
+
Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
@@ -349,10 +356,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0] layout<horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
```
+
+ Example 4: Masked store of a 32-bit float element ZA tile with vertical layout to memory.
+ ```mlir
+ arm_sme.tile_store %tile, %base[%c0, %c0], %mask layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ ```
}];
let arguments = (ins SMETile:$valueToStore,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
- Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
+ Variadic<Index>:$indices, Optional<AnyVector>:$mask,
+ ArmSME_TileSliceLayoutAttr:$layout
);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
@@ -363,9 +376,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
}
}];
+ let builders = [
+ OpBuilder<(ins "Value":$valueToStore, "Value":$base,
+ "ValueRange":$indices), [{
+ build($_builder, $_state, valueToStore, base, indices, {});
+ }]>,
+ ];
+
let assemblyFormat =
- "$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
- "`:` type($base) `,` type($valueToStore)";
+ "$valueToStore `,` $base `[` $indices `]` (`,` $mask^)? (`layout` `` $layout^)?"
+ "attr-dict `:` type($base) `,` type($valueToStore)";
}
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d06eb4f5b01c950..b5d8a956253b570 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -144,8 +144,8 @@ struct TransferWriteToArmSMELowering
return failure();
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
- writeOp, writeOp.getVector(), writeOp.getSource(),
- writeOp.getIndices());
+ writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
+ writeOp.getMask());
return success();
}
};
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 25c62f78d843543..d2775aa9f3610b2 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -150,3 +150,17 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_store
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask : vector<[1]x[1]xi1>, %dest : memref<?x?xi8>) {
+ %c0 = arith.constant 0 : index
+ // expected-note@-2 {{prior use here}}
+ // expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[16]x[16]xi1>' vs 'vector<[1]x[1]xi1>}}
+ arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 6866137267dc66a..eeabfa1409a644c 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -624,6 +624,15 @@ func.func @arm_sme_tile_store_ver_f64(%tile : vector<[2]x[2]xf64>, %dest : memre
// -----
+func.func @arm_sme_tile_store_with_mask_ver_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %mask : vector<[4]x[4]xi1>) {
+ // CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %c0 = arith.constant 0 : index
+ arm_sme.tile_store %tile, %dest[%c0, %c0], %mask layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
/// Layout is optional and horizontal is the default, verify it's still parsed.
func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
// CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 455b47a83e28f43..e5ebdf223f315d6 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -315,6 +315,20 @@ func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?
// -----
+// CHECK-LABEL: func.func @transfer_write_2d_with_mask_f64(
+// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xf64>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf64>,
+// CHECK-SAME: %[[MASK:.*]]: vector<[2]x[2]xi1>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vector, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref<?x?xf64>
+ return
+}
+
+// -----
+
// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
// lowering only occurs for vector types of correct rank, shape, element size
// and number of scalable dims.
More information about the Mlir-commits
mailing list