[Mlir-commits] [mlir] 65a6be5 - [mlir][ArmSME] Use memref indices for load and store
Cullen Rhodes
llvmlistbot at llvm.org
Thu Aug 3 01:50:26 PDT 2023
Author: Cullen Rhodes
Date: 2023-08-03T08:50:12Z
New Revision: 65a6be5de97adec09534d37019e384e374069ce7
URL: https://github.com/llvm/llvm-project/commit/65a6be5de97adec09534d37019e384e374069ce7
DIFF: https://github.com/llvm/llvm-project/commit/65a6be5de97adec09534d37019e384e374069ce7.diff
LOG: [mlir][ArmSME] Use memref indices for load and store
This patch extends the ArmSME load and store op lowering to use the
memref indices. An integration test is added that loads two 32-bit element
ZA tiles from memory and stores them back to memory in reverse order, to
verify the indices are correctly applied.
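
For illustration, a sketch of the new SCF lowering for the rank-2 case,
based on the updated arm-sme-to-scf.mlir test below (the %c123 base index
and the SSA value names are illustrative only). A tile load such as

  %tile = arm_sme.tile_load %src[%c123, %c0]
    : memref<?x?xi32>, vector<[4]x[4]xi32>

now lowers to a loop in which the tile slice index is added to the first
memref index rather than replacing it (for rank-1 memrefs the tile slice
index is first multiplied by the number of elements per tile slice):

  scf.for %tile_slice_idx = %c0 to %num_tile_slices step %c1 {
    // First index = base index + tile slice index; the second index is
    // passed through unchanged.
    %offset = arith.addi %c123, %tile_slice_idx : index
    arm_sme.load_tile_slice %src[%offset, %c0], %init_tile, %tile_slice_idx
      : memref<?x?xi32>, vector<[4]x[4]xi32>
  }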
Depends on D156467 D156558
Reviewed By: awarzynski, dcaballe
Differential Revision: https://reviews.llvm.org/D156689
Added:
Modified:
mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index e143726cf234f2..4028a7ad0870b5 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -26,6 +26,28 @@ namespace mlir {
using namespace mlir;
namespace {
+/// Adjusts `indices` as follows for a given tile slice and returns them in
+/// `outIndices`:
+/// rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
+/// rank 2: (indices[0] + tileSliceIndex, indices[1])
+void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
+ Value tileSliceNumElts,
+ SmallVectorImpl<Value> &outIndices, Location loc,
+ PatternRewriter &rewriter) {
+ assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+
+ auto tileSliceOffset = tileSliceIndex;
+ if (rank == 1)
+ tileSliceOffset =
+ rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
+
+ auto baseIndexPlusTileSliceOffset =
+ rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
+ outIndices.push_back(baseIndexPlusTileSliceOffset);
+
+ if (rank == 2)
+ outIndices.push_back(indices[1]);
+}
/// Lower `arm_sme.tile_load` to a loop over the tile slices and load each slice
/// using `arm_sme.load_tile_slice`.
@@ -77,6 +99,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
auto vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ // This describes both the number of ZA tile slices and the number of
+ // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
+ // ..., SVL_Q).
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
auto forOp =
@@ -84,13 +109,16 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
rewriter.setInsertionPointToStart(forOp.getBody());
+ // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
+ // tile.
+ SmallVector<Value> memrefIndices;
auto tileSliceIndex = forOp.getInductionVar();
- // TODO: use indices
- // Create 'arm_sme.load_tile_slice' to load tile slice from
- // memory into tile.
- rewriter.create<arm_sme::LoadTileSliceOp>(
- loc, tileType, tileLoadOp.getBase(), tile, tileSliceIndex,
- tileSliceIndex);
+ getMemrefIndices(tileLoadOp.getIndices(),
+ tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+ numTileSlices, memrefIndices, loc, rewriter);
+ rewriter.create<arm_sme::LoadTileSliceOp>(loc, tileType,
+ tileLoadOp.getBase(), tile,
+ memrefIndices, tileSliceIndex);
rewriter.setInsertionPointAfter(forOp);
@@ -139,6 +167,9 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
auto vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ // This describes both the number of ZA tile slices and the number of
+ // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
+ // ..., SVL_Q).
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
auto forOp =
@@ -146,11 +177,14 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
rewriter.setInsertionPointToStart(forOp.getBody());
+ SmallVector<Value> memrefIndices;
auto tileSliceIndex = forOp.getInductionVar();
- // TODO: use indices
+ getMemrefIndices(tileStoreOp.getIndices(),
+ tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
+ numTileSlices, memrefIndices, loc, rewriter);
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
- tileStoreOp.getBase(), tileSliceIndex);
+ tileStoreOp.getBase(), memrefIndices);
return success();
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index d61bde971647dc..e1df09ff9e0758 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -111,33 +111,6 @@ Value castTileIDToI32(Value tile, Location loc,
return tile;
}
-/// Returns the following
-/// * for rank 2 memrefs `tileSliceIndex`, since `getStridedElementPtr` does
-/// the arithmetic.
-/// * for rank 1 memrefs `tileSliceIndex * tileSliceNumElts`, adjusting the
-/// index by the number of elements in a vector of SVL bits.
-/// * otherwise throws an unreachable error.
-Value getTileSlicePtrIndex(unsigned rank, Value tileSliceIndex,
- Value tileSliceNumElts, Location loc,
- ConversionPatternRewriter &rewriter) {
- assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
-
- auto tileSliceIndexI64 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI64Type(), tileSliceIndex);
-
- if (rank == 1) {
- auto tileSliceNumEltsI64 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI64Type(), tileSliceNumElts);
- return rewriter.create<arith::MulIOp>(loc, tileSliceIndexI64,
- tileSliceNumEltsI64);
- }
-
- if (rank == 2)
- return tileSliceIndexI64;
-
- llvm_unreachable("memref has unexpected rank!");
-}
-
/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
struct LoadTileSliceToArmSMELowering
: public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
@@ -159,25 +132,11 @@ struct LoadTileSliceToArmSMELowering
loc, rewriter.getIntegerType(tileElementWidth),
loadTileSliceOp.getTile());
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
- auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- // This describes both the number of ZA tile slices and the number of
- // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
- // ..., SVL_Q).
- auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
+ adaptor.getBase(),
+ adaptor.getIndices(), rewriter);
- // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice.
- auto memRefType = loadTileSliceOp.getMemRefType();
auto tileSlice = loadTileSliceOp.getTileSliceIndex();
- // TODO: The 'indices' argument for the 'base' memref is currently ignored,
- // 'tileSliceIndex' should be added to 'indices[0]'.
- Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
- numTileSlices, loc, rewriter);
- Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(),
- {tileSliceIndex}, rewriter);
// Cast tile slice to i32 for intrinsic.
auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
@@ -192,6 +151,7 @@ struct LoadTileSliceToArmSMELowering
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
auto tileI32 = castTileIDToI32(tile, loc, rewriter);
+ // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice.
switch (tileElementWidth) {
default:
llvm_unreachable("unexpected element type!");
@@ -243,25 +203,12 @@ struct StoreTileSliceToArmSMELowering
loc, rewriter.getIntegerType(tileElementWidth),
storeTileSliceOp.getTile());
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
- auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- // This describes both the number of ZA tile slices and the number of
- // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
- // ..., SVL_Q).
- auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-
// Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
- auto memRefType = storeTileSliceOp.getMemRefType();
+ Value ptr = this->getStridedElementPtr(
+ loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
+ adaptor.getIndices(), rewriter);
+
auto tileSlice = storeTileSliceOp.getTileSliceIndex();
- // TODO: The 'indices' argument for the 'base' memref is currently ignored,
- // 'tileSliceIndex' should be added to 'indices[0]'.
- Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
- numTileSlices, loc, rewriter);
- Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(),
- {tileSliceIndex}, rewriter);
// Cast tile slice to i32 for intrinsic.
auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index b64c79663038c0..9ab1d79794d765 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -2,15 +2,16 @@
// CHECK-LABEL: func.func @arm_sme_tile_load(
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
-// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[TILE_SLICE_INDEX]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -28,7 +29,8 @@ func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[TILE_SLICE_INDEX]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
func.func @arm_sme_tile_store(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 402da661cfee44..ddd55319d347f3 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -8,17 +8,19 @@
// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[C255:.*]] = arith.constant 255 : i32
// CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK-DAG: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
// CHECK-DAG: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
// CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK: %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64
// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
-// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]] : i64
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
// CHECK-NEXT: "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
@@ -31,32 +33,41 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
// -----
-// CHECK-LABEL: @vector_load_i8(
-// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
+// Load an 8-bit tile from a rank 2 memref with a non-zero offset for the first
+// memref index. This verifies the offset is preserved when materializing the
+// loop of tile slice loads.
+
+// CHECK-LABEL: @vector_load_i8_with_offset(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
// CHECK-DAG: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C123:.*]] = arith.constant 123 : index
// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK-DAG: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
// CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT: %[[TILE_SLICE_PLUS_OFF0:.*]] = arith.addi %[[TILE_SLICE]], %[[C123]] : index
+// CHECK-NEXT: %[[TILE_SLICE_PLUS_OFF0_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE_PLUS_OFF0]] : index to i64
// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
-// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_PLUS_OFF0_I64]], %[[STRIDE0]] : i64
+// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]] : i64
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
// CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
-func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
+func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
%c0 = arith.constant 0 : index
- %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %c123 = arith.constant 123 : index
+ %tile = vector.load %arg0[%c123, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
return %tile : vector<[16]x[16]xi8>
}
@@ -75,14 +86,10 @@ func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK-NEXT: %[[VSCALE_1:.*]] = "llvm.intr.vscale"() : () -> i64
-// CHECK-NEXT: %[[VSCALE_IDX_1:.*]] = builtin.unrealized_conversion_cast %[[VSCALE_1]] : i64 to index
-// CHECK-NEXT: %[[SVL_B_1:.*]] = arith.muli %[[VSCALE_IDX_1]], %[[MIN_SVL_B]] : index
-// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
-// CHECK-NEXT: %[[SVL_B_I64:.*]] = arith.index_castui %[[SVL_B_1]] : index to i64
-// CHECK-NEXT: %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE_I64]], %[[SVL_B_I64]] : i64
+// CHECK-NEXT: %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE]], %[[SVL_B]] : index
+// CHECK-NEXT: %[[TILE_SLICE_IDX_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE_IDX]] : index to i64
// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX_I64]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
// CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
@@ -218,17 +225,19 @@ func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
// CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
// CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
+// CHECK: %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64
// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
-// CHECK: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
-// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]] : i64
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
// CHECK-NEXT: "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index 0da4b1cf319e6d..72cfc10e7b7cf5 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -1,21 +1,31 @@
-// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
-// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// RUN: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
-// RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \
-// RUN: mlir-translate -mlir-to-llvmir | \
-// RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
-// RUN: --entry-function=za0_d_f64 \
-// RUN: --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s --check-prefix=CHECK-ZA0_D
-
-// Integration test demonstrating load/store to/from SME ZA tile.
+// DEFINE: %{entry_point} = za0_d_f64
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE: -march=aarch64 -mattr=+sve,+sme \
+// DEFINE: -e %{entry_point} -entry-point-result=i32 \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=CHECK-ZA0_D
+
+// REDEFINE: %{entry_point} = load_store_two_za_s_tiles
+// RUN: %{compile} | %{run} | FileCheck %s
+
+// Integration tests demonstrating load/store to/from SME ZA tile.
llvm.func @printF64(f64)
+llvm.func @printI64(i64)
llvm.func @printOpen()
llvm.func @printClose()
llvm.func @printComma()
llvm.func @printNewline()
+llvm.func @printCString(!llvm.ptr<i8>)
+// This test verifies a 64-bit element ZA with FP64 data is correctly
+// loaded/stored to/from memory.
func.func @za0_d_f64() -> i32 {
%c0 = arith.constant 0 : index
%c0_f64 = arith.constant 0.0 : f64
@@ -191,3 +201,174 @@ func.func @za0_d_f64() -> i32 {
%c0_i32 = arith.constant 0 : i32
return %c0_i32 : i32
}
+
+func.func @printTileBegin() {
+ %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
+ %1 = llvm.mlir.constant(0 : index) : i64
+ %2 = llvm.getelementptr %0[%1, %1]
+ : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
+ llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+ return
+}
+
+func.func @printTileEnd() {
+ %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
+ %1 = llvm.mlir.constant(0 : index) : i64
+ %2 = llvm.getelementptr %0[%1, %1]
+ : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
+ llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+ return
+}
+
+// This test loads two 32-bit element ZA tiles from memory and stores them back
+// to memory in reverse order. This verifies the memref indices for the vector
+// load and store are correctly preserved since the second tile is offset from
+// the first tile.
+func.func @load_store_two_za_s_tiles() -> i32 {
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c1_i32 = arith.constant 1 : i32
+ %c2_i32 = arith.constant 2 : i32
+ %c1_index = arith.constant 1 : index
+ %c2_index = arith.constant 2 : index
+
+ %min_elts_s = arith.constant 4 : index
+ %vscale = vector.vscale
+
+ // "svl" refers to the Streaming Vector Length and "svl_s" can mean either:
+ // * the number of 32-bit elements in a vector of SVL bits.
+ // * the number of tile slices (1d vectors) in a 32-bit element tile.
+ %svl_s = arith.muli %min_elts_s, %vscale : index
+
+ // Allocate memory for two 32-bit element tiles.
+ %size_of_tile = arith.muli %svl_s, %svl_s : index
+ %size_of_two_tiles = arith.muli %size_of_tile, %c2_index : index
+ %mem1 = memref.alloca(%size_of_two_tiles) : memref<?xi32>
+
+ // Fill memory that tile 1 will be loaded from with '1' and '2' for tile 2.
+ //
+ // For example, assuming an SVL of 128-bits and two 4x4xi32 tiles:
+ //
+ // tile 1
+ //
+ // 1, 1, 1, 1
+ // 1, 1, 1, 1
+ // 1, 1, 1, 1
+ // 1, 1, 1, 1
+ //
+ // tile 2
+ //
+ // 2, 2, 2, 2
+ // 2, 2, 2, 2
+ // 2, 2, 2, 2
+ // 2, 2, 2, 2
+ //
+ scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
+ %isFirstTile = arith.cmpi ult, %i, %size_of_tile : index
+ %val = scf.if %isFirstTile -> i32 {
+ scf.yield %c1_i32 : i32
+ } else {
+ scf.yield %c2_i32 : i32
+ }
+ %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
+ vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+ }
+
+ // Dump "mem1". The smallest SVL is 128-bits so each tile will be at least
+ // 4x4xi32.
+ //
+ // CHECK: ( 1, 1, 1, 1
+ // CHECK-NEXT: ( 1, 1, 1, 1
+ // CHECK-NEXT: ( 1, 1, 1, 1
+ // CHECK-NEXT: ( 1, 1, 1, 1
+ // CHECK: ( 2, 2, 2, 2
+ // CHECK-NEXT: ( 2, 2, 2, 2
+ // CHECK-NEXT: ( 2, 2, 2, 2
+ // CHECK-NEXT: ( 2, 2, 2, 2
+ scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
+ %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+
+ llvm.call @printOpen() : () -> ()
+ scf.for %i2 = %c0 to %svl_s step %c1_index {
+ %elem = vector.extractelement %tileslice[%i2 : index] : vector<[4]xi32>
+ %elem_i64 = llvm.zext %elem : i32 to i64
+ llvm.call @printI64(%elem_i64) : (i64) -> ()
+ %last_i = arith.subi %svl_s, %c1_index : index
+ %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
+ scf.if %isNotLastIter {
+ llvm.call @printComma() : () -> ()
+ }
+ }
+ llvm.call @printClose() : () -> ()
+ llvm.call @printNewline() : () -> ()
+ }
+
+ // Load tile 1 from memory
+ %za0_s = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+ // Load tile 2 from memory
+ %za1_s = vector.load %mem1[%size_of_tile] : memref<?xi32>, vector<[4]x[4]xi32>
+
+ // Allocate new memory to store tiles to
+ %mem2 = memref.alloca(%size_of_two_tiles) : memref<?xi32>
+
+ // Zero new memory
+ scf.for %i = %c0 to %size_of_two_tiles step %c1_index {
+ memref.store %c0_i32, %mem2[%i] : memref<?xi32>
+ }
+
+ // Stores tiles back to (new) memory in reverse order
+
+ // Store tile 2 to memory
+ vector.store %za1_s, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+ // Store tile 1 to memory
+ vector.store %za0_s, %mem2[%size_of_tile] : memref<?xi32>, vector<[4]x[4]xi32>
+
+ // Dump "mem2" and check the tiles were stored in reverse order. The smallest
+ // SVL is 128-bits so the tiles will be at least 4x4xi32.
+ //
+ // CHECK: TILE BEGIN
+ // CHECK-NEXT: ( 2, 2, 2, 2
+ // CHECK-NEXT: ( 2, 2, 2, 2
+ // CHECK-NEXT: ( 2, 2, 2, 2
+ // CHECK-NEXT: ( 2, 2, 2, 2
+ // CHECK: TILE END
+ // CHECK-NEXT: TILE BEGIN
+ // CHECK-NEXT: ( 1, 1, 1, 1
+ // CHECK-NEXT: ( 1, 1, 1, 1
+ // CHECK-NEXT: ( 1, 1, 1, 1
+ // CHECK-NEXT: ( 1, 1, 1, 1
+ // CHECK: TILE END
+ func.call @printTileBegin() : () -> ()
+ scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
+ %av = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
+
+ llvm.call @printOpen() : () -> ()
+ scf.for %i2 = %c0 to %svl_s step %c1_index {
+ %elem = vector.extractelement %av[%i2 : index] : vector<[4]xi32>
+ %elem_i64 = llvm.zext %elem : i32 to i64
+ llvm.call @printI64(%elem_i64) : (i64) -> ()
+ %last_i = arith.subi %svl_s, %c1_index : index
+ %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
+ scf.if %isNotLastIter {
+ llvm.call @printComma() : () -> ()
+ }
+ }
+ llvm.call @printClose() : () -> ()
+ llvm.call @printNewline() : () -> ()
+
+ %tileSizeMinusStep = arith.subi %size_of_tile, %svl_s : index
+ %isNextTile = arith.cmpi eq, %i, %tileSizeMinusStep : index
+ scf.if %isNextTile {
+ func.call @printTileEnd() : () -> ()
+ func.call @printTileBegin() : () -> ()
+ }
+ }
+ func.call @printTileEnd() : () -> ()
+
+ return %c0_i32 : i32
+}
+
+llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
+llvm.mlir.global internal constant @str_tile_end("TILE END\0A")