[Mlir-commits] [mlir] [mlir][ArmSME] Audit ArmSME load/store ops (PR #139573)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue May 13 01:25:18 PDT 2025
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/139573
>From 836689d8794e288630534abbf143f0545960dbec Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 11 Apr 2025 16:03:07 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Audit ArmSME load/store ops
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This patch updates the following arm_sme ops to require that input and
output element types match:
* `arm_sme.tile_load`, `arm_sme.tile_store`,
`arm_sme.load_tile_slice`, `arm_sme.store_tile_slice`.
In addition, it ensures that the base memref operand for `tile_load` and
`tile_store` is always rank-2, aligning with the semantics of Arm SME
tiles (always rank-2). This change is effectively a follow-up to #135151:
* "[mlir][vector] Tighten the semantics of vector.{load|store}"
The patch also updates `createLoadStoreForOverTileSlices` in
ArmSMEToSCF.cpp to fail when processing invalid tile stores like the
following:
```mlir
arm_sme.tile_store %arg0, %arg1[%c0] : memref<?x4xi8>, vector<[4]x[4]xi32>
```
This particular change fixes #118769. As noted in the TODO, we should
further extend op verification logic — I plan to address that in a
follow-up patch.
---
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 8 +-
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 13 ++-
mlir/test/Dialect/ArmSME/invalid.mlir | 81 +++++++++++++++++--
3 files changed, 88 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 6fd992afbf043..2f083b55d4904 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -317,6 +317,7 @@ def CopyTileOp : ArmSME_Op<"copy_tile", [
def TileLoadOp : ArmSME_Op<"tile_load", [
ArmSMETileOpInterface,
AttrSizedOperandSegments,
+ AllElementTypesMatch<["result", "base"]>,
OptionalTypesMatchWith<
"padding type matches element type of result",
"result", "padding",
@@ -369,7 +370,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
```
}];
let arguments = (ins
- Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
+ Arg<MemRefRankOf<[AnyType], [2]>, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices,
Optional<AnyType>:$padding, Optional<AnyVectorOfNonZeroRank>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
@@ -407,6 +408,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
def TileStoreOp : ArmSME_Op<"tile_store", [
ArmSMETileOpInterface,
AttrSizedOperandSegments,
+ AllElementTypesMatch<["valueToStore", "base"]>,
HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
]> {
let summary = "Tile store operation";
@@ -443,7 +445,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
```
}];
let arguments = (ins SMETile:$valueToStore,
- Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
+ Arg<MemRefRankOf<[AnyType], [2]>, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices, Optional<AnyVectorOfNonZeroRank>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
);
@@ -473,6 +475,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
ArmSMETileOpInterface,
+ AllElementTypesMatch<["tile", "base"]>,
AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask">
]> {
let summary = "Tile slice load and update operation";
@@ -535,6 +538,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
ArmSMETileOpInterface,
+ AllElementTypesMatch<["tile", "base"]>,
TileSliceMaskConstraint<"tile", "mask">
]> {
let summary = "Tile slice store operation";
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 630414030d98b..458628c29c6ac 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -33,20 +33,15 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
Value tileSliceIndex,
Value tileSliceNumElts, Location loc,
PatternRewriter &rewriter) {
- assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+ assert(rank == 2 && "memref has unexpected rank!");
SmallVector<Value, 2> outIndices;
auto tileSliceOffset = tileSliceIndex;
- if (rank == 1)
- tileSliceOffset =
- rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
auto baseIndexPlusTileSliceOffset =
rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
outIndices.push_back(baseIndexPlusTileSliceOffset);
-
- if (rank == 2)
- outIndices.push_back(indices[1]);
+ outIndices.push_back(indices[1]);
return outIndices;
}
@@ -60,6 +55,10 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
makeLoopBody) {
PatternRewriter::InsertionGuard guard(rewriter);
+ // TODO: This case should be captured and rejected by a verifier.
+ if (memrefIndices.size() != 2)
+ return rewriter.notifyMatchFailure(loc, "invalid number of indices");
+
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
auto vscale =
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 700b2412ff7a7..8c5a098a0c785 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -50,7 +50,7 @@ func.func @arm_sme_get_tile__bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
// -----
-func.func @arm_sme_insert_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
+func.func @arm_sme_insert_tile_slice_i8__bad_vector_length(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
%c0 = arith.constant 0 : index
// expected-error at +1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
%0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[8]xi8> into vector<[16]x[16]xi8>
@@ -59,23 +59,40 @@ func.func @arm_sme_insert_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8
// -----
-func.func @arm_sme_insert_tile_slice_f32__bad_vector_type(%vector : vector<[8]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
+func.func @arm_sme_insert_tile_slice_f32__bad_vector_length(%vector : vector<[8]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
%c0 = arith.constant 0 : index
// expected-error at +1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
%0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[8]xf32> into vector<[4]x[4]xf32>
return %0 : vector<[4]x[4]xf32>
}
+// -----
+
+func.func @arm_sme_insert_tile_slice__bad_element_type(%vector : vector<[4]xf64>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
+ %c0 = arith.constant 0 : index
+ // expected-error at +1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
+ %0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[4]xf64> into vector<[4]x[4]xf32>
+ return %0 : vector<[4]x[4]xf32>
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.extract_tile_slice
//===----------------------------------------------------------------------===//
// -----
-func.func @arm_sme_extract_tile_slice__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
+func.func @arm_sme_extract_tile_slice__bad_result_length(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf32> {
+ // expected-error at +1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
+ %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[2]xf32> from vector<[4]x[4]xf32>
+ return %0 : vector<[2]xf32>
+}
+
+// -----
+
+func.func @arm_sme_extract_tile_slice__bad_result_element_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf64> {
// expected-error at +1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
- %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
- return %0 : vector<[2]xf64>
+ %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[4]xf64> from vector<[4]x[4]xf32>
+ return %0 : vector<[4]xf64>
}
//===----------------------------------------------------------------------===//
@@ -111,6 +128,24 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
return
}
+// -----
+
+func.func @arm_sme_tile_load__bad_memref_rank(%src : memref<?xf64>, %pad : f64) {
+ %c0 = arith.constant 0 : index
+ // expected-error at +1 {{op operand #0 must be 2D memref of any type values, but got 'memref<?xf64>'}}
+ %tile = arm_sme.tile_load %src[%c0], %pad, : memref<?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load__bad_element_type(%src : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error at +1 {{failed to verify that all of {result, base} have same element type}}
+ %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[16]x[16]xi8>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.load_tile_slice
//===----------------------------------------------------------------------===//
@@ -124,6 +159,15 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
return
}
+// -----
+
+func.func @arm_sme_load_tile_slice__bad_element_type(%src : memref<?x?xi32>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ // expected-error at +1 {{op failed to verify that all of {tile, base} have same element type}}
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[16]xi1>, vector<[16]x[16]xi8>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.tile_store
//===----------------------------------------------------------------------===//
@@ -138,6 +182,24 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask
return
}
+// -----
+
+func.func @arm_sme_tile_store__bad_memref_rank(%tile : vector<[16]x[16]xi8>, %dest : memref<?xi8>) {
+ %c0 = arith.constant 0 : index
+ // expected-error at +1 {{op operand #1 must be 2D memref of any type values, but got 'memref<?xi8>'}}
+ arm_sme.tile_store %tile, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_store__bad_element_type(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ // expected-error at +1 {{op failed to verify that all of {valueToStore, base} have same element type}}
+ arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[16]x[16]xi8>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.store_tile_slice
//===----------------------------------------------------------------------===//
@@ -152,6 +214,15 @@ func.func @arm_sme_store_tile_slice__bad_mask_type(%tile : vector<[16]x[16]xi8>,
return
}
+// -----
+
+func.func @arm_sme_store_tile_slice__bad_element_type(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi32>) -> () {
+ %c0 = arith.constant 0 : index
+ // expected-error at +1 {{op failed to verify that all of {tile, base} have same element type}}
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0, %c0] : memref<?x?xi32>, vector<[16]xi1>, vector<[16]x[16]xi8>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.outerproduct
//===----------------------------------------------------------------------===//
>From 579bff7bf3019f0c164cd22e6000f87fd1f3433d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 13 May 2025 08:24:49 +0000
Subject: [PATCH 2/2] fixup! [mlir][ArmSME] Audit ArmSME load/store ops
Update the load-vertical.mlir integration test to use a rank-2 memref, as `arm_sme.tile_load` now requires.
---
.../Dialect/Vector/CPU/ArmSME/load-vertical.mlir | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir
index b7144be08a853..8d4b4a07994e2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir
@@ -17,7 +17,7 @@ func.func @entry() {
%za_s_size = arith.muli %svl_s, %svl_s : index
// Allocate memory.
- %mem1 = memref.alloca(%za_s_size) : memref<?xi32>
+ %mem1 = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
// Fill each "row" of "mem1" with row number.
//
@@ -29,15 +29,15 @@ func.func @entry() {
// 3, 3, 3, 3
//
%init_0 = arith.constant 0 : i32
- scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
+ scf.for %i = %c0 to %svl_s step %c1 iter_args(%val = %init_0) -> (i32) {
%splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
- vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+ vector.store %splat_val, %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
%val_next = arith.addi %val, %c1_i32 : i32
scf.yield %val_next : i32
}
// Load tile from "mem1" vertically.
- %0 = arm_sme.tile_load %mem1[%c0, %c0] layout<vertical> : memref<?xi32>, vector<[4]x[4]xi32>
+ %0 = arm_sme.tile_load %mem1[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
// 1. ORIGINAL HORIZONTAL LAYOUT
// Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
@@ -50,8 +50,8 @@ func.func @entry() {
// CHECK-NEXT: ( 3, 3, 3, 3
// CHECK: TILE END
vector.print str "TILE BEGIN\n"
- scf.for %i = %c0 to %za_s_size step %svl_s {
- %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+ scf.for %i = %c0 to %svl_s step %c1 {
+ %tileslice = vector.load %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
vector.print %tileslice : vector<[4]xi32>
}
vector.print str "TILE END\n"
More information about the Mlir-commits
mailing list