[Mlir-commits] [mlir] 65a6be5 - [mlir][ArmSME] Use memref indices for load and store

Cullen Rhodes llvmlistbot at llvm.org
Thu Aug 3 01:50:26 PDT 2023


Author: Cullen Rhodes
Date: 2023-08-03T08:50:12Z
New Revision: 65a6be5de97adec09534d37019e384e374069ce7

URL: https://github.com/llvm/llvm-project/commit/65a6be5de97adec09534d37019e384e374069ce7
DIFF: https://github.com/llvm/llvm-project/commit/65a6be5de97adec09534d37019e384e374069ce7.diff

LOG: [mlir][ArmSME] Use memref indices for load and store

This patch extends the ArmSME load and store op lowering to use the
memref indices. To verify this, an integration test is added that loads
two 32-bit element ZA tiles from memory and stores them back to memory
in reverse order.

Depends on D156467 D156558

Reviewed By: awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D156689

Added: 
    

Modified: 
    mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
    mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
    mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index e143726cf234f2..4028a7ad0870b5 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -26,6 +26,28 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
+/// Adjusts `indices` as follows for a given tile slice and returns them in
+/// `outIndices`:
+///   rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
+///   rank 2: (indices[0] + tileSliceIndex, indices[1])
+void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
+                      Value tileSliceNumElts,
+                      SmallVectorImpl<Value> &outIndices, Location loc,
+                      PatternRewriter &rewriter) {
+  assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+
+  auto tileSliceOffset = tileSliceIndex;
+  if (rank == 1)
+    tileSliceOffset =
+        rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
+
+  auto baseIndexPlusTileSliceOffset =
+      rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
+  outIndices.push_back(baseIndexPlusTileSliceOffset);
+
+  if (rank == 2)
+    outIndices.push_back(indices[1]);
+}
 
 /// Lower `arm_sme.tile_load` to a loop over the tile slices and load each slice
 /// using `arm_sme.load_tile_slice`.
@@ -77,6 +99,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     auto vscale =
         rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    // This describes both the number of ZA tile slices and the number of
+    // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
+    // ..., SVL_Q).
     auto numTileSlices =
         rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
     auto forOp =
@@ -84,13 +109,16 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
+    // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
+    // tile.
+    SmallVector<Value> memrefIndices;
     auto tileSliceIndex = forOp.getInductionVar();
-    // TODO: use indices
-    // Create 'arm_sme.load_tile_slice' to load tile slice from
-    // memory into tile.
-    rewriter.create<arm_sme::LoadTileSliceOp>(
-        loc, tileType, tileLoadOp.getBase(), tile, tileSliceIndex,
-        tileSliceIndex);
+    getMemrefIndices(tileLoadOp.getIndices(),
+                     tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+                     numTileSlices, memrefIndices, loc, rewriter);
+    rewriter.create<arm_sme::LoadTileSliceOp>(loc, tileType,
+                                              tileLoadOp.getBase(), tile,
+                                              memrefIndices, tileSliceIndex);
 
     rewriter.setInsertionPointAfter(forOp);
 
@@ -139,6 +167,9 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
     auto vscale =
         rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    // This describes both the number of ZA tile slices and the number of
+    // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
+    // ..., SVL_Q).
     auto numTileSlices =
         rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
     auto forOp =
@@ -146,11 +177,14 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
+    SmallVector<Value> memrefIndices;
     auto tileSliceIndex = forOp.getInductionVar();
-    // TODO: use indices
+    getMemrefIndices(tileStoreOp.getIndices(),
+                     tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
+                     numTileSlices, memrefIndices, loc, rewriter);
     rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
         tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
-        tileStoreOp.getBase(), tileSliceIndex);
+        tileStoreOp.getBase(), memrefIndices);
 
     return success();
   }

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index d61bde971647dc..e1df09ff9e0758 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -111,33 +111,6 @@ Value castTileIDToI32(Value tile, Location loc,
   return tile;
 }
 
-/// Returns the following
-/// * for rank 2 memrefs `tileSliceIndex`, since `getStridedElementPtr` does
-///   the arithmetic.
-/// * for rank 1 memrefs `tileSliceIndex * tileSliceNumElts`, adjusting the
-///   index by the number of elements in a vector of SVL bits.
-/// * otherwise throws an unreachable error.
-Value getTileSlicePtrIndex(unsigned rank, Value tileSliceIndex,
-                           Value tileSliceNumElts, Location loc,
-                           ConversionPatternRewriter &rewriter) {
-  assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
-
-  auto tileSliceIndexI64 = rewriter.create<arith::IndexCastUIOp>(
-      loc, rewriter.getI64Type(), tileSliceIndex);
-
-  if (rank == 1) {
-    auto tileSliceNumEltsI64 = rewriter.create<arith::IndexCastUIOp>(
-        loc, rewriter.getI64Type(), tileSliceNumElts);
-    return rewriter.create<arith::MulIOp>(loc, tileSliceIndexI64,
-                                          tileSliceNumEltsI64);
-  }
-
-  if (rank == 2)
-    return tileSliceIndexI64;
-
-  llvm_unreachable("memref has unexpected rank!");
-}
-
 /// Lower `arm_sme.load_tile_slice` to SME intrinsics.
 struct LoadTileSliceToArmSMELowering
     : public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
@@ -159,25 +132,11 @@ struct LoadTileSliceToArmSMELowering
         loc, rewriter.getIntegerType(tileElementWidth),
         loadTileSliceOp.getTile());
 
-    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
-    auto vscale =
-        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-    // This describes both the number of ZA tile slices and the number of
-    // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
-    // ..., SVL_Q).
-    auto numTileSlices =
-        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+    Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
+                                           adaptor.getBase(),
+                                           adaptor.getIndices(), rewriter);
 
-    // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice.
-    auto memRefType = loadTileSliceOp.getMemRefType();
     auto tileSlice = loadTileSliceOp.getTileSliceIndex();
-    // TODO: The 'indices' argument for the 'base' memref is currently ignored,
-    // 'tileSliceIndex' should be added to 'indices[0]'.
-    Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
-                                                numTileSlices, loc, rewriter);
-    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                           {tileSliceIndex}, rewriter);
 
     // Cast tile slice to i32 for intrinsic.
     auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
@@ -192,6 +151,7 @@ struct LoadTileSliceToArmSMELowering
     auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
 
     auto tileI32 = castTileIDToI32(tile, loc, rewriter);
+    // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice.
     switch (tileElementWidth) {
     default:
       llvm_unreachable("unexpected element type!");
@@ -243,25 +203,12 @@ struct StoreTileSliceToArmSMELowering
         loc, rewriter.getIntegerType(tileElementWidth),
         storeTileSliceOp.getTile());
 
-    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
-    auto vscale =
-        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-    // This describes both the number of ZA tile slices and the number of
-    // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
-    // ..., SVL_Q).
-    auto numTileSlices =
-        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-
     // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
-    auto memRefType = storeTileSliceOp.getMemRefType();
+    Value ptr = this->getStridedElementPtr(
+        loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
+        adaptor.getIndices(), rewriter);
+
     auto tileSlice = storeTileSliceOp.getTileSliceIndex();
-    // TODO: The 'indices' argument for the 'base' memref is currently ignored,
-    // 'tileSliceIndex' should be added to 'indices[0]'.
-    Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
-                                                numTileSlices, loc, rewriter);
-    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                           {tileSliceIndex}, rewriter);
 
     // Cast tile slice to i32 for intrinsic.
     auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(

diff  --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index b64c79663038c0..9ab1d79794d765 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -2,15 +2,16 @@
 
 // CHECK-LABEL: func.func @arm_sme_tile_load(
 // CHECK-SAME:                               %[[SRC:.*]]: memref<?x?xi32>) {
-// CHECK-NEXT:    %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-NEXT:    %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
 // CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[TILE_SLICE_INDEX]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -28,7 +29,8 @@ func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
 // CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
 // CHECK:         %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK:         scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK:           arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[TILE_SLICE_INDEX]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK:           %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK:           arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_store(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 402da661cfee44..ddd55319d347f3 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -8,17 +8,19 @@
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
 // CHECK-DAG:  %[[C255:.*]] = arith.constant 255 : i32
 // CHECK-DAG:  %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK-DAG:  %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
 // CHECK-DAG:  "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
 // CHECK-DAG:  %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK:        %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK:        %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]]  : i64
-// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]]  : i64
+// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
 // CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
 // CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
@@ -31,32 +33,41 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
 
 // -----
 
-// CHECK-LABEL: @vector_load_i8(
-// CHECK-SAME:                  %[[ARG0:.*]]: memref<?x?xi8>)
+// Load an 8-bit tile from a rank 2 memref with a non-zero offset for the first
+// memref index. This verifies the offset is preserved when materializing the
+// loop of tile slice loads.
+
+// CHECK-LABEL: @vector_load_i8_with_offset(
+// CHECK-SAME:                              %[[ARG0:.*]]: memref<?x?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
 // CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:  %[[C123:.*]] = arith.constant 123 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
 // CHECK-DAG:  %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK-DAG:  %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
 // CHECK-DAG:  %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK:        %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT:   %[[TILE_SLICE_PLUS_OFF0:.*]] = arith.addi %[[TILE_SLICE]], %[[C123]] : index
+// CHECK-NEXT:   %[[TILE_SLICE_PLUS_OFF0_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE_PLUS_OFF0]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]]  : i64
-// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_PLUS_OFF0_I64]], %[[STRIDE0]]  : i64
+// CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]]  : i64
+// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
 // CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
 // CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 // CHECK-NEXT: }
 // CHECK-NEXT: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
-func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
+func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
-  %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %c123 = arith.constant 123 : index
+  %tile = vector.load %arg0[%c123, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
   return %tile : vector<[16]x[16]xi8>
 }
 
@@ -75,14 +86,10 @@ func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK-NEXT:   %[[VSCALE_1:.*]] = "llvm.intr.vscale"() : () -> i64
-// CHECK-NEXT:   %[[VSCALE_IDX_1:.*]] = builtin.unrealized_conversion_cast %[[VSCALE_1]] : i64 to index
-// CHECK-NEXT:   %[[SVL_B_1:.*]] = arith.muli %[[VSCALE_IDX_1]], %[[MIN_SVL_B]] : index
-// CHECK-NEXT:   %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
-// CHECK-NEXT:   %[[SVL_B_I64:.*]] = arith.index_castui %[[SVL_B_1]] : index to i64
-// CHECK-NEXT:   %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE_I64]], %[[SVL_B_I64]] : i64
+// CHECK-NEXT:   %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE]], %[[SVL_B]] : index
+// CHECK-NEXT:   %[[TILE_SLICE_IDX_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE_IDX]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX_I64]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
 // CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
 // CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
@@ -218,17 +225,19 @@ func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
+// CHECK-DAG:  %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
 // CHECK-DAG:  %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
 // CHECK-DAG:  %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
+// CHECK:        %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
-// CHECK:        %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]]  : i64
-// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]]  : i64
+// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
 // CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
 // CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index 0da4b1cf319e6d..72cfc10e7b7cf5 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -1,21 +1,31 @@
-// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// RUN:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
-// RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
-// RUN: mlir-translate -mlir-to-llvmir | \
-// RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
-// RUN:   --entry-function=za0_d_f64 \
-// RUN:   --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s --check-prefix=CHECK-ZA0_D
-
-// Integration test demonstrating load/store to/from SME ZA tile.
+// DEFINE: %{entry_point} = za0_d_f64
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:  -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:  -e %{entry_point} -entry-point-result=i32 \
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=CHECK-ZA0_D
+
+// REDEFINE: %{entry_point} = load_store_two_za_s_tiles
+// RUN: %{compile} | %{run} | FileCheck %s
+
+// Integration tests demonstrating load/store to/from SME ZA tile.
 
 llvm.func @printF64(f64)
+llvm.func @printI64(i64)
 llvm.func @printOpen()
 llvm.func @printClose()
 llvm.func @printComma()
 llvm.func @printNewline()
+llvm.func @printCString(!llvm.ptr<i8>)
 
+// This test verifies a 64-bit element ZA tile with FP64 data is correctly
+// loaded/stored to/from memory.
 func.func @za0_d_f64() -> i32 {
   %c0 = arith.constant 0 : index
   %c0_f64 = arith.constant 0.0 : f64
@@ -191,3 +201,174 @@ func.func @za0_d_f64() -> i32 {
   %c0_i32 = arith.constant 0 : i32
   return %c0_i32 : i32
 }
+
+func.func @printTileBegin() {
+  %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @printTileEnd() {
+  %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+// This test loads two 32-bit element ZA tiles from memory and stores them back
+// to memory in reverse order. This verifies the memref indices for the vector
+// load and store are correctly preserved since the second tile is offset from
+// the first tile.
+func.func @load_store_two_za_s_tiles() -> i32 {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c2_i32 = arith.constant 2 : i32
+  %c1_index = arith.constant 1 : index
+  %c2_index = arith.constant 2 : index
+
+  %min_elts_s = arith.constant 4 : index
+  %vscale = vector.vscale
+
+  // "svl" refers to the Streaming Vector Length and "svl_s" can mean either:
+  // * the number of 32-bit elements in a vector of SVL bits.
+  // * the number of tile slices (1d vectors) in a 32-bit element tile.
+  %svl_s = arith.muli %min_elts_s, %vscale : index
+
+  // Allocate memory for two 32-bit element tiles.
+  %size_of_tile = arith.muli %svl_s, %svl_s : index
+  %size_of_two_tiles = arith.muli %size_of_tile, %c2_index : index
+  %mem1 = memref.alloca(%size_of_two_tiles) : memref<?xi32>
+
+  // Fill the memory for tile 1 with '1' and the memory for tile 2 with '2'.
+  //
+  // For example, assuming an SVL of 128-bits and two 4x4xi32 tiles:
+  //
+  // tile 1
+  //
+  //   1, 1, 1, 1
+  //   1, 1, 1, 1
+  //   1, 1, 1, 1
+  //   1, 1, 1, 1
+  //
+  // tile 2
+  //
+  //   2, 2, 2, 2
+  //   2, 2, 2, 2
+  //   2, 2, 2, 2
+  //   2, 2, 2, 2
+  //
+  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
+    %isFirstTile = arith.cmpi ult, %i, %size_of_tile : index
+    %val = scf.if %isFirstTile -> i32 {
+      scf.yield %c1_i32 : i32
+    } else {
+      scf.yield %c2_i32 : i32
+    }
+    %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
+    vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+  }
+
+  // Dump "mem1". The smallest SVL is 128-bits so each tile will be at least
+  // 4x4xi32.
+  //
+  // CHECK:      ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK:      ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
+    %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+
+    llvm.call @printOpen() : () -> ()
+    scf.for %i2 = %c0 to %svl_s step %c1_index {
+      %elem = vector.extractelement %tileslice[%i2 : index] : vector<[4]xi32>
+      %elem_i64 = llvm.zext %elem : i32 to i64
+      llvm.call @printI64(%elem_i64) : (i64) -> ()
+      %last_i = arith.subi %svl_s, %c1_index : index
+      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
+      scf.if %isNotLastIter {
+        llvm.call @printComma() : () -> ()
+      }
+    }
+    llvm.call @printClose() : () -> ()
+    llvm.call @printNewline() : () -> ()
+  }
+
+  // Load tile 1 from memory
+  %za0_s = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Load tile 2 from memory
+  %za1_s = vector.load %mem1[%size_of_tile] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Allocate new memory to store tiles to
+  %mem2 = memref.alloca(%size_of_two_tiles)  : memref<?xi32>
+
+  // Zero new memory
+  scf.for %i = %c0 to %size_of_two_tiles step %c1_index {
+    memref.store %c0_i32, %mem2[%i] : memref<?xi32>
+  }
+
+  // Store tiles back to (new) memory in reverse order
+
+  // Store tile 2 to memory
+  vector.store %za1_s, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Store tile 1 to memory
+  vector.store %za0_s, %mem2[%size_of_tile] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Dump "mem2" and check the tiles were stored in reverse order. The smallest
+  // SVL is 128-bits so the tiles will be at least 4x4xi32.
+  //
+  // CHECK:      TILE BEGIN
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK:      TILE END
+  // CHECK-NEXT: TILE BEGIN
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK:      TILE END
+  func.call @printTileBegin() : () -> ()
+  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
+    %av = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
+
+    llvm.call @printOpen() : () -> ()
+    scf.for %i2 = %c0 to %svl_s step %c1_index {
+      %elem = vector.extractelement %av[%i2 : index] : vector<[4]xi32>
+      %elem_i64 = llvm.zext %elem : i32 to i64
+      llvm.call @printI64(%elem_i64) : (i64) -> ()
+      %last_i = arith.subi %svl_s, %c1_index : index
+      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
+      scf.if %isNotLastIter {
+        llvm.call @printComma() : () -> ()
+      }
+    }
+    llvm.call @printClose() : () -> ()
+    llvm.call @printNewline() : () -> ()
+
+    %tileSizeMinusStep = arith.subi %size_of_tile, %svl_s : index
+    %isNextTile = arith.cmpi eq, %i, %tileSizeMinusStep : index
+    scf.if %isNextTile {
+      func.call @printTileEnd() : () -> ()
+      func.call @printTileBegin() : () -> ()
+    }
+  }
+  func.call @printTileEnd() : () -> ()
+
+  return %c0_i32 : i32
+}
+
+llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
+llvm.mlir.global internal constant @str_tile_end("TILE END\0A")


        


More information about the Mlir-commits mailing list