[Mlir-commits] [mlir] ba2b21a - [mlir][ArmSME] Audit ArmSME load/store ops (#139573)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 13 05:01:05 PDT 2025


Author: Andrzej Warzyński
Date: 2025-05-13T13:01:01+01:00
New Revision: ba2b21a584219055c1c8106ba81ca49db538a6a5

URL: https://github.com/llvm/llvm-project/commit/ba2b21a584219055c1c8106ba81ca49db538a6a5
DIFF: https://github.com/llvm/llvm-project/commit/ba2b21a584219055c1c8106ba81ca49db538a6a5.diff

LOG: [mlir][ArmSME] Audit ArmSME load/store ops (#139573)

This patch updates the following ArmSME ops to require that the element
type of the memref operand matches the element type of the tile:
  * `arm_sme.tile_load`, `arm_sme.tile_store`,
    `arm_sme.load_tile_slice`, `arm_sme.store_tile_slice`.

In addition, it ensures that the base memref operand for `tile_load` and
`tile_store` is always rank-2, aligning with the semantics of Arm SME
tiles (always rank-2). This change is effectively a follow-up to
#135151:

  * "[mlir][vector] Tighten the semantics of vector.{load|store}"

The patch also updates `createLoadStoreForOverTileSlices` in
ArmSMEToSCF.cpp to fail when processing invalid tile stores like the
following:

```mlir
arm_sme.tile_store %arg0, %arg1[%c0] : memref<?x4xi8>, vector<[4]x[4]xi32>
```

This particular change fixes #118769. As noted in the TODO, we should
further extend op verification logic — I plan to address that in a
follow-up patch.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
    mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
    mlir/test/Dialect/ArmSME/invalid.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 6fd992afbf043..2f083b55d4904 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -317,6 +317,7 @@ def CopyTileOp : ArmSME_Op<"copy_tile", [
 def TileLoadOp : ArmSME_Op<"tile_load", [
   ArmSMETileOpInterface,
   AttrSizedOperandSegments,
+  AllElementTypesMatch<["result", "base"]>,
   OptionalTypesMatchWith<
     "padding type matches element type of result",
     "result", "padding",
@@ -369,7 +370,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
     ```
   }];
   let arguments = (ins
-    Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
+    Arg<MemRefRankOf<[AnyType], [2]>, "the reference to load from", [MemRead]>:$base,
     Variadic<Index>:$indices,
     Optional<AnyType>:$padding, Optional<AnyVectorOfNonZeroRank>:$mask,
     ArmSME_TileSliceLayoutAttr:$layout
@@ -407,6 +408,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
 def TileStoreOp : ArmSME_Op<"tile_store", [
   ArmSMETileOpInterface,
   AttrSizedOperandSegments,
+  AllElementTypesMatch<["valueToStore", "base"]>,
   HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
 ]> {
   let summary = "Tile store operation";
@@ -443,7 +445,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
     ```
   }];
   let arguments = (ins SMETile:$valueToStore,
-    Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
+    Arg<MemRefRankOf<[AnyType], [2]>, "the reference to store to", [MemWrite]>:$base,
     Variadic<Index>:$indices, Optional<AnyVectorOfNonZeroRank>:$mask,
     ArmSME_TileSliceLayoutAttr:$layout
   );
@@ -473,6 +475,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
   ArmSMETileOpInterface,
+  AllElementTypesMatch<["tile", "base"]>,
   AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask">
 ]> {
   let summary = "Tile slice load and update operation";
@@ -535,6 +538,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
 
 def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
   ArmSMETileOpInterface,
+  AllElementTypesMatch<["tile", "base"]>,
   TileSliceMaskConstraint<"tile", "mask">
 ]> {
   let summary = "Tile slice store operation";

diff  --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 630414030d98b..458628c29c6ac 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -33,20 +33,15 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
                                        Value tileSliceIndex,
                                        Value tileSliceNumElts, Location loc,
                                        PatternRewriter &rewriter) {
-  assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+  assert(rank == 2 && "memref has unexpected rank!");
   SmallVector<Value, 2> outIndices;
 
   auto tileSliceOffset = tileSliceIndex;
-  if (rank == 1)
-    tileSliceOffset =
-        rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
 
   auto baseIndexPlusTileSliceOffset =
       rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
   outIndices.push_back(baseIndexPlusTileSliceOffset);
-
-  if (rank == 2)
-    outIndices.push_back(indices[1]);
+  outIndices.push_back(indices[1]);
 
   return outIndices;
 }
@@ -60,6 +55,10 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
         makeLoopBody) {
   PatternRewriter::InsertionGuard guard(rewriter);
 
+  // TODO: This case should be captured and rejected by a verifier.
+  if (memrefIndices.size() != 2)
+    return rewriter.notifyMatchFailure(loc, "invalid number of indices");
+
   auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
       loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
   auto vscale =

diff  --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 700b2412ff7a7..8c5a098a0c785 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -50,7 +50,7 @@ func.func @arm_sme_get_tile__bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
 
 // -----
 
-func.func @arm_sme_insert_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
+func.func @arm_sme_insert_tile_slice_i8__bad_vector_length(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
   // expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
   %0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[8]xi8> into vector<[16]x[16]xi8>
@@ -59,23 +59,40 @@ func.func @arm_sme_insert_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8
 
 // -----
 
-func.func @arm_sme_insert_tile_slice_f32__bad_vector_type(%vector : vector<[8]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
+func.func @arm_sme_insert_tile_slice_f32__bad_vector_length(%vector : vector<[8]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
   %c0 = arith.constant 0 : index
   // expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
   %0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[8]xf32> into vector<[4]x[4]xf32>
   return %0 : vector<[4]x[4]xf32>
 }
 
+// -----
+
+func.func @arm_sme_insert_tile_slice__bad_element_type(%vector : vector<[4]xf64>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}}
+  %0 = arm_sme.insert_tile_slice %vector, %tile[%tile_slice_index] : vector<[4]xf64> into vector<[4]x[4]xf32>
+  return %0 : vector<[4]x[4]xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.extract_tile_slice
 //===----------------------------------------------------------------------===//
 
 // -----
 
-func.func @arm_sme_extract_tile_slice__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
+func.func @arm_sme_extract_tile_slice__bad_result_length(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf32> {
+  // expected-error@+1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
+  %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[2]xf32> from vector<[4]x[4]xf32>
+  return %0 : vector<[2]xf32>
+}
+
+// -----
+
+func.func @arm_sme_extract_tile_slice__bad_result_element_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf64> {
   // expected-error@+1 {{op failed to verify that type of 'result' matches type of 'tile' slice}}
-  %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
-  return %0 : vector<[2]xf64>
+  %0 = arm_sme.extract_tile_slice %tile[%tile_slice_index] : vector<[4]xf64> from vector<[4]x[4]xf32>
+  return %0 : vector<[4]xf64>
 }
 
 //===----------------------------------------------------------------------===//
@@ -111,6 +128,24 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
   return
 }
 
+// -----
+
+func.func @arm_sme_tile_load__bad_memref_rank(%src : memref<?xf64>, %pad : f64) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{op operand #0 must be 2D memref of any type values, but got 'memref<?xf64>'}}
+  %tile = arm_sme.tile_load %src[%c0], %pad, : memref<?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load__bad_element_type(%src : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{failed to verify that all of {result, base} have same element type}}
+  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[16]x[16]xi8>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.load_tile_slice
 //===----------------------------------------------------------------------===//
@@ -124,6 +159,15 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
   return
 }
 
+// -----
+
+func.func @arm_sme_load_tile_slice__bad_element_type(%src : memref<?x?xi32>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{op failed to verify that all of {tile, base} have same element type}}
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[16]xi1>, vector<[16]x[16]xi8>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.tile_store
 //===----------------------------------------------------------------------===//
@@ -138,6 +182,24 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask
   return
 }
 
+// -----
+
+func.func @arm_sme_tile_store__bad_memref_rank(%tile : vector<[16]x[16]xi8>, %dest : memref<?xi8>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{op operand #1 must be 2D memref of any type values, but got 'memref<?xi8>'}}
+  arm_sme.tile_store %tile, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_store__bad_element_type(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{op failed to verify that all of {valueToStore, base} have same element type}}
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[16]x[16]xi8>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.store_tile_slice
 //===----------------------------------------------------------------------===//
@@ -152,6 +214,15 @@ func.func @arm_sme_store_tile_slice__bad_mask_type(%tile : vector<[16]x[16]xi8>,
   return
 }
 
+// -----
+
+func.func @arm_sme_store_tile_slice__bad_element_type(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi32>) -> () {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{op failed to verify that all of {tile, base} have same element type}}
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0, %c0] : memref<?x?xi32>, vector<[16]xi1>, vector<[16]x[16]xi8>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.outerproduct
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir
index b7144be08a853..8d4b4a07994e2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-vertical.mlir
@@ -17,7 +17,7 @@ func.func @entry() {
   %za_s_size = arith.muli %svl_s, %svl_s : index
 
   // Allocate memory.
-  %mem1 = memref.alloca(%za_s_size) : memref<?xi32>
+  %mem1 = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
 
   // Fill each "row" of "mem1" with row number.
   //
@@ -29,15 +29,15 @@ func.func @entry() {
   //   3, 3, 3, 3
   //
   %init_0 = arith.constant 0 : i32
-  scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
+  scf.for %i = %c0 to %svl_s step %c1 iter_args(%val = %init_0) -> (i32) {
     %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
-    vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+    vector.store %splat_val, %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
     %val_next = arith.addi %val, %c1_i32 : i32
     scf.yield %val_next : i32
   }
 
   // Load tile from "mem1" vertically.
-  %0 = arm_sme.tile_load %mem1[%c0, %c0] layout<vertical> : memref<?xi32>, vector<[4]x[4]xi32>
+  %0 = arm_sme.tile_load %mem1[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
 
   // 1. ORIGINAL HORIZONTAL LAYOUT
   // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
@@ -50,8 +50,8 @@ func.func @entry() {
   // CHECK-NEXT: ( 3, 3, 3, 3
   // CHECK:      TILE END
   vector.print str "TILE BEGIN\n"
-  scf.for %i = %c0 to %za_s_size step %svl_s {
-    %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+  scf.for %i = %c0 to %svl_s step %c1 {
+    %tileslice = vector.load %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
     vector.print %tileslice : vector<[4]xi32>
   }
   vector.print str "TILE END\n"


        


More information about the Mlir-commits mailing list