[Mlir-commits] [mlir] [mlir][ArmSME] Audit arm_sme.tile_store (PR #135396)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Apr 11 12:09:50 PDT 2025
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/135396
>From 93cfa92cd41811cce3da511b7c0221ab95d1a292 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 11 Apr 2025 16:03:07 +0000
Subject: [PATCH] [mlir][ArmSME] Audit arm_sme.tile_store
Makes sure that the following cases are rejected (as opposed to causing
mlir-opt to crash):
```mlir
arm_sme.tile_store %arg0, %arg1[%c0] : memref<?x4xi8>, vector<[4]x[4]xi32>
arm_sme.tile_store %arg0, %arg1[%c0] : memref<?xi8>, vector<[4]x[4]xi32>
```
Instead, we should be storing to a rank-2 MemRef, using 2 indices,
e.g.:
```mlir
arm_sme.tile_store %arg0, %arg1[%c0, %c0] : memref<?x?xi8>, vector<[4]x[4]xi32>
```
Fixes https://github.com/llvm/llvm-project/issues/118769
---
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 4 ++--
.../lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 13 ++++++-------
mlir/test/Dialect/ArmSME/invalid.mlir | 18 ++++++++++++++++++
3 files changed, 26 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 6fd992afbf043..23eab706c856d 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -369,7 +369,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
```
}];
let arguments = (ins
- Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
+ Arg<MemRefRankOf<[AnyType], [2]>, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices,
Optional<AnyType>:$padding, Optional<AnyVectorOfNonZeroRank>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
@@ -443,7 +443,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
```
}];
let arguments = (ins SMETile:$valueToStore,
- Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
+ Arg<MemRefRankOf<[AnyType], [2]>, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices, Optional<AnyVectorOfNonZeroRank>:$mask,
ArmSME_TileSliceLayoutAttr:$layout
);
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 6ed29903ea407..9bdafb7d8c501 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -33,20 +33,15 @@ SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
Value tileSliceIndex,
Value tileSliceNumElts, Location loc,
PatternRewriter &rewriter) {
- assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+ assert(rank == 2 && "memref has unexpected rank!");
SmallVector<Value, 2> outIndices;
auto tileSliceOffset = tileSliceIndex;
- if (rank == 1)
- tileSliceOffset =
- rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
auto baseIndexPlusTileSliceOffset =
rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
outIndices.push_back(baseIndexPlusTileSliceOffset);
-
- if (rank == 2)
- outIndices.push_back(indices[1]);
+ outIndices.push_back(indices[1]);
return outIndices;
}
@@ -60,6 +55,10 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
makeLoopBody) {
PatternRewriter::InsertionGuard guard(rewriter);
+ // TODO: This case should be captured and rejected by a verifier.
+ if (memrefIndices.size() != 2)
+ return rewriter.notifyMatchFailure(loc, "invalid number of indices");
+
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
auto vscale =
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 700b2412ff7a7..c015fe7cf1641 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -111,6 +111,15 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
return
}
+// -----
+
+func.func @arm_sme_tile_load__bad_memref_rank(%src : memref<?xf64>, %pad : f64) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{op operand #0 must be 2D memref of any type values, but got 'memref<?xf64>'}}
+ %tile = arm_sme.tile_load %src[%c0], %pad, : memref<?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.load_tile_slice
//===----------------------------------------------------------------------===//
@@ -138,6 +147,15 @@ func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask
return
}
+// -----
+
+func.func @arm_sme_tile_store__bad_memref_rank(%tile : vector<[16]x[16]xi8>, %dest : memref<?xi8>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{op operand #1 must be 2D memref of any type values, but got 'memref<?xi8>'}}
+ arm_sme.tile_store %tile, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
//===----------------------------------------------------------------------===//
// arm_sme.store_tile_slice
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list