[Mlir-commits] [mlir] [mlir][ArmSME] Audit ArmSME load/store ops (PR #139573)

Andrzej Warzyński llvmlistbot at llvm.org
Tue May 13 01:25:18 PDT 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/139573

>From 836689d8794e288630534abbf143f0545960dbec Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 11 Apr 2025 16:03:07 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Audit ArmSME load/store ops
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch updates the following arm_sme ops to require that input and
output element types match:
  * `arm_sme.tile_load`, `arm_sme.tile_store`,
    `arm_sme.load_tile_slice`, `arm_sme.store_tile_slice`.

In addition, it ensures that the base memref operand for `tile_load` and
`tile_store` is always rank-2, aligning with the semantics of Arm SME
tiles (always rank-2). This change is effectively a follow-up to #135151:

  * "[mlir][vector] Tighten the semantics of vector.{load|store}"

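The patch also updates `createLoadStoreForOverTileSlices` in
ArmSMEToSCF.cpp to fail (rather than proceed) when the number of memref
indices is not 2, i.e. when processing invalid tile stores like the
following, which passes a single index for a rank-2 memref: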

```mlir
arm_sme.tile_store %arg0, %arg1[%c0] : memref<?x4xi8>, vector<[4]x[4]xi32>
```

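This particular change fixes #118769. As noted in the TODO added to
`createLoadStoreForOverTileSlices`, this case should ultimately be
rejected by an op verifier, so we should further extend the op
verification logic — I plan to address that in a follow-up patch.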
---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       |  8 +-
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    | 13 ++-
 mlir/test/Dialect/ArmSME/invalid.mlir         | 81 +++++++++++++++++--
 3 files changed, 88 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 6fd992afbf043..2f083b55d4904 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -317,6 +317,7 @@ def CopyTileOp : ArmSME_Op<"copy_tile", [
 def TileLoadOp : ArmSME_Op<"tile_load", [
   ArmSMETileOpInterface,
   AttrSizedOperandSegments,
+  AllElementTypesMatch<["result", "base"]>,
   OptionalTypesMatchWith<
     "padding type matches element type of result",
     "result", "padding",
@@ -369,7 +370,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
     ```
   }];
   let arguments = (ins
-    Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
+    Arg<MemRefRankOf<[AnyType], [2]>, "the reference to load from", [MemRead]>:$base,
     Variadic<Index>:$indices,
     Optional<AnyType>:$padding, Optional<AnyVectorOfNonZeroRank>:$mask,
     ArmSME_TileSliceLayoutAttr:$layout
@@ -407,6 +408,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
 def TileStoreOp : ArmSME_Op<"tile_store", [
   ArmSMETileOpInterface,
   AttrSizedOperandSegments,
+  AllElementTypesMatch<["valueToStore", "base"]>,
   HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
 ]> {
   let summary = "Tile store operation";
@@ -443,7 +445,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
     ```
   }];
   let arguments = (ins SMETile:$valueToStore,
-    Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
+    Arg<MemRefRankOf<[AnyType], [2]>, "the reference to store to", [MemWrite]>:$base,
     Variadic<Index>:$indices, Optional<AnyVectorOfNonZeroRank>:$mask,
     ArmSME_TileSliceLayoutAttr:$layout
   );
@@ -473,6 +475,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
   ArmSMETileOpInterface,
+  AllElementTypesMatch<["tile", "base"]>,
   AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask">
 ]> {
   let summary = "Tile slice load and update operation";
@@ -535,6 +538,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
 
 def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
   ArmSMETileOpInterface,
+  AllElementTypesMatch<["tile", "base"]>,
   TileSliceMaskConstraint<"tile", "mask">
 ]> {
   let summary = "Tile slice store operation";
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 630414030d98b..458628c29c6ac 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -33,20 +33,15 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
                                        Value tileSliceIndex,
                                        Value tileSliceNumElts, Location loc,
                                        PatternRewriter &rewriter) {
-  assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+  assert(rank == 2 && "memref has unexpected rank!");
   SmallVector<Value, 2> outIndices;
 
   auto tileSliceOffset = tileSliceIndex;
-  if (rank == 1)
-    tileSliceOffset =
-        rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
 
   auto baseIndexPlusTileSliceOffset =
       rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
   outIndices.push_back(baseIndexPlusTileSliceOffset);
-
-  if (rank == 2)
-    outIndices.push_back(indices[1]);
+  outIndices.push_back(indices[1]);
 
   return outIndices;
 }
@@ -60,6 +55,10 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
         makeLoopBody) {
   PatternRewriter::InsertionGuard guard(rewriter);
 
+  // TODO: This case should be captured and rejected by a verifier.
+  if (memrefIndices.size() != 2)
+    return rewriter.notifyMatchFailure(loc, "invalid number of indices");
+
   auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
       loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
   auto vscale =
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 700b2412ff7a7..8c5a098a0c785 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -50,7 +50,7 @@ func.func @arm_sme_get_tile__bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
 
 // -----
 
-func.func @arm_sme_insert_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
+func.func @arm_sme_insert_tile_slice_i8__bad_vector_length(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
   // expected-error at +1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
   %0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[8]xi8> into vector<[16]x[16]xi8>
@@ -59,23 +59,40 @@ func.func @arm_sme_insert_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8
 
 // -----
 
-func.func @arm_sme_insert_tile_slice_f32__bad_vector_type(%vector : vector<[8]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
+func.func @arm_sme_insert_tile_slice_f32__bad_vector_length(%vector : vector<[8]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
   %c0 = arith.constant 0 : index
   // expected-error at +1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
   %0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[8]xf32> into vector<[4]x[4]xf32>
   return %0 : vector<[4]x[4]xf32>
 }
 
+// -----
+
+func.func @arm_sme_insert_tile_slice__bad_element_type(%vector : vector<[4]xf64>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  // expected-error at +1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
+  %0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[4]xf64> into vector<[4]x[4]xf32>
+  return %0 : vector<[4]x[4]xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.extract_tile_slice
 //===----------------------------------------------------------------------===//
 
 // -----
 
-func.func @arm_sme_extract_tile_slice__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
+func.func @arm_sme_extract_tile_slice__bad_result_length(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf32> {
+  // expected-error at +1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
+  %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[2]xf32> from vector<[4]x[4]xf32>
+  return %0 : vector<[2]xf32>
+}
+
+// -----
+
+func.func @arm_sme_extract_tile_slice__bad_result_element_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf64> {
   // expected-error at +1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
-  %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
-  return %0 : vector<[2]xf64>
+  %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[4]xf64> from vector<[4]x[4]xf32>
+  return %0 : vector<[4]xf64>
 }
 
 //===----------------------------------------------------------------------===//
@@ -111,6 +128,24 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
   return
 }
 
+// -----
+
+func.func @arm_sme_tile_load__bad_memref_rank(%src : memref<?xf64>, %pad : f64) {
+  %c0 = arith.constant 0 : index
+  // expected-error at +1 {{op operand #0 must be 2D memref of any type values, but got 'memref<?xf64>'}}
+  %tile = arm_sme.tile_load %src[%c0], %pad, : memref<?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load__bad_element_type(%src : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error at +1 {{failed to verify that all of {result, base} have same element type}}
+  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[16]x[16]xi8>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.load_tile_slice
 //===----------------------------------------------------------------------===//
@@ -124,6 +159,15 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
   return
 }
 
+// -----
+
+func.func @arm_sme_load_tile_slice__bad_element_type(%src : memref<?x?xi32>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  // expected-error at +1 {{op failed to verify that all of {tile, base} have same element type}}
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[16]xi1>, vector<[16]x[16]xi8>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.tile_store
 //===----------------------------------------------------------------------===//
@@ -138,6 +182,24 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask
   return
 }
 
+// -----
+
+func.func @arm_sme_tile_store__bad_memref_rank(%tile : vector<[16]x[16]xi8>, %dest : memref<?xi8>) {
+  %c0 = arith.constant 0 : index
+  // expected-error at +1 {{op operand #1 must be 2D memref of any type values, but got 'memref<?xi8>'}}
+  arm_sme.tile_store %tile, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_store__bad_element_type(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error at +1 {{op failed to verify that all of {valueToStore, base} have same element type}}
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[16]x[16]xi8>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.store_tile_slice
 //===----------------------------------------------------------------------===//
@@ -152,6 +214,15 @@ func.func @arm_sme_store_tile_slice__bad_mask_type(%tile : vector<[16]x[16]xi8>,
   return
 }
 
+// -----
+
+func.func @arm_sme_store_tile_slice__bad_element_type(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi32>) -> () {
+  %c0 = arith.constant 0 : index
+  // expected-error at +1 {{op failed to verify that all of {tile, base} have same element type}}
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0, %c0] : memref<?x?xi32>, vector<[16]xi1>, vector<[16]x[16]xi8>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.outerproduct
 //===----------------------------------------------------------------------===//

>From 579bff7bf3019f0c164cd22e6000f87fd1f3433d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 13 May 2025 08:24:49 +0000
Subject: [PATCH 2/2] fixup! [mlir][ArmSME] Audit ArmSME load/store ops

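Update the `load-vertical.mlir` integration test to allocate and access
a rank-2 memref, as now required by `arm_sme.tile_load`.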
---
 .../Dialect/Vector/CPU/ArmSME/load-vertical.mlir     | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir
index b7144be08a853..8d4b4a07994e2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir
@@ -17,7 +17,7 @@ func.func @entry() {
   %za_s_size = arith.muli %svl_s, %svl_s : index
 
   // Allocate memory.
-  %mem1 = memref.alloca(%za_s_size) : memref<?xi32>
+  %mem1 = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
 
   // Fill each "row" of "mem1" with row number.
   //
@@ -29,15 +29,15 @@ func.func @entry() {
   //   3, 3, 3, 3
   //
   %init_0 = arith.constant 0 : i32
-  scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
+  scf.for %i = %c0 to %svl_s step %c1 iter_args(%val = %init_0) -> (i32) {
     %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
-    vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+    vector.store %splat_val, %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
     %val_next = arith.addi %val, %c1_i32 : i32
     scf.yield %val_next : i32
   }
 
   // Load tile from "mem1" vertically.
-  %0 = arm_sme.tile_load %mem1[%c0, %c0] layout<vertical> : memref<?xi32>, vector<[4]x[4]xi32>
+  %0 = arm_sme.tile_load %mem1[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
 
   // 1. ORIGINAL HORIZONTAL LAYOUT
   // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
@@ -50,8 +50,8 @@ func.func @entry() {
   // CHECK-NEXT: ( 3, 3, 3, 3
   // CHECK:      TILE END
   vector.print str "TILE BEGIN\n"
-  scf.for %i = %c0 to %za_s_size step %svl_s {
-    %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+  scf.for %i = %c0 to %svl_s step %c1 {
+    %tileslice = vector.load %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
     vector.print %tileslice : vector<[4]xi32>
   }
   vector.print str "TILE END\n"



More information about the Mlir-commits mailing list