[PATCH] D155306: [mlir][ArmSME] Add tile load op and extend tile store tile size support

Andrzej Warzynski via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 20 01:32:51 PDT 2023


awarzynski added a comment.

Thanks for the updates, I've left a few more nits/suggestions, but nothing major.



================
Comment at: mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h:9
+//
+// This header file defines prototypes for various transformation utilities for
+// the ArmSME dialect. These are not passes by themselves but are used
----------------
[nit]


================
Comment at: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp:114-118
+/// Returns `offset` if memref is rank 2, otherwise adjusts `offset` by the
+/// number of elements in a vector of SVL bits.
+Value getOffset(MemRefType memRefType, Value offset, Value vscale,
+                Value minElems, Location loc,
+                ConversionPatternRewriter &rewriter) {
----------------
Right, IIUC, this is something like:

```
Scale the memory offset, i.e. `vnum`, if needed:
 * for rank 2 memrefs, `getStridedElementPtr` does the calculation for us, so just return `vnum`.
 * for rank 1 memrefs, assume row-major storage and scale by the effective vector length.
```

Btw, this makes a lot of sense. I'd just like us to be very clear about the meaning of `offset` and `vnum` in this context. The latter name, imho, carries a bit of helpful context, hence the suggestion to rename. In the context of `getStridedElementPtr`, `offset` probably makes more sense. `getOffset` also feels a bit too generic 🤔.
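To make the `offset` vs. `vnum` distinction concrete, here is a minimal C++ sketch of the arithmetic as described in this thread. It uses plain integers instead of MLIR `Value`s. The name `scaleTileSliceOffset` and its parameters are hypothetical, not the patch's API. It assumes `minElems` is the number of elements in 128 bits (`128 / esize`) and that rank-1 memrefs are stored row-major.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, integer-only model of the offset computation discussed in
// this review; the real lowering builds MLIR arith ops on `Value`s instead.
//   vnum     - index of the tile slice (one ZA vector) to load/store
//   vscale   - runtime multiple of 128 bits (SVL = vscale * 128 bits)
//   minElems - number of elements in 128 bits for the element type
std::int64_t scaleTileSliceOffset(unsigned memrefRank, std::int64_t vnum,
                                  std::int64_t vscale, std::int64_t minElems) {
  assert((memrefRank == 1 || memrefRank == 2) && "expected rank 1 or 2 memref");
  // Rank 2: getStridedElementPtr already multiplies the row index by the row
  // stride, so the slice index is passed through unchanged.
  if (memrefRank == 2)
    return vnum;
  // Rank 1: assume row-major storage and advance by one effective vector
  // length (SVL<t> elements) per slice.
  return vnum * vscale * minElems;
}
```

For example, for `i32` elements (`minElems = 4`) on a 256-bit SVL (`vscale = 2`), slice 3 of a rank-1 memref starts at element `3 * 2 * 4 = 24`. For a rank-2 memref the function simply returns 3 and leaves the striding to `getStridedElementPtr`.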


================
Comment at: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp:177
+    // Create 'arm_sme.get_tile_id' op.
+    unsigned width = vType.getElementType().getIntOrFloatBitWidth();
+    auto tile = rewriter.create<arm_sme::GetTileID>(
----------------
[nit] Suggestion for a more descriptive name (it's a rather key bit in the SME logic)


================
Comment at: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp:108-109
+
+/// Returns `offset` if memref is rank 2, otherwise adjusts `offset` by SVL<t>
+/// bytes.
+Value getOffset(MemRefType memRefType, Value offset, Value vscale,
----------------
c-rhodes wrote:
> awarzynski wrote:
> > What "offset" is it? Why do we need to adjust it in the 1-D case and just return as is in 2-D?
> > 
> > In the 1-D case, is it:
> > ```
> > offset * vscale * minElems
> > ```
> > ? And is `minElems` the minimum number of elements in a scalable vector? So basically the "base size of an SVE vector"?
> > What "offset" is it? Why do we need to adjust it in the 1-D case and just return as is in 2-D?
> 
> The offset to the load or store pointer. In the 2D case `getStridedElementPtr` does the arithmetic for us, but in the 1D case we have to do it ourselves.
> 
> > In the 1-D case, is it:
> > ```
> > offset * vscale * minElems
> > ```
> 
> Yeah, and `offset` is `vnum`, so it's `vnum * vscale * minElems`. That is, the offset is the number of elements of a given type in a vector of SVL bits (SVLt), scaled by `vnum`, so the base gets incremented one tile vector at a time.
> 
> > ? And is `minElems` the minimum number of elements in a scalable vector? So basically the "base size of an SVE vector"?
> 
> Yeah, so `128 / esize`. Perhaps `getMinNumElts` would be better implemented like that rather than with a switch actually.
> 
> As for where the support for both 1D and 2D memrefs came from, I initially started with these integration tests that dump ZA but could get it to work with 2D memrefs: https://gist.github.com/c-rhodes/1e9f2d8fd0ca3c6539f167e08079f6ab
> 
> I found those tests useful for verification, but since the output varies depending on the runtime VL we can't add these tests.

Yeah, it would be nice to include them. Wouldn't it be possible to add `CHECK` lines that assume the minimum possible VL for each type? Not in this patch though; it's quite large as is.
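The `128 / esize` formulation suggested above, and the per-type sizes that minimum-VL `CHECK` lines would assume, can be sketched as follows. `minNumElts` and `tileSliceNumElts` are hypothetical names, not functions from the patch. The sketch assumes the element bit width evenly divides 128.

```cpp
#include <cassert>

// Hypothetical stand-in for `getMinNumElts`: the number of elements of a
// given bit width that fit in 128 bits, computed directly rather than via a
// switch over element types.
unsigned minNumElts(unsigned elementBitWidth) {
  assert(elementBitWidth > 0 && 128 % elementBitWidth == 0 &&
         "element width must evenly divide 128 bits");
  return 128 / elementBitWidth;
}

// Number of elements in one tile slice for a given vscale
// (SVL = vscale * 128 bits). At the minimum vscale of 1 this equals
// minNumElts, i.e. the size a minimum-VL CHECK line would assume.
unsigned tileSliceNumElts(unsigned elementBitWidth, unsigned vscale) {
  return vscale * minNumElts(elementBitWidth);
}
```

At `vscale = 1` (a 128-bit SVL), this gives 16 `i8` elements per slice, matching the `vector<[16]x[16]xi8>` tile type used in the round-trip tests. For `i16` it gives 8, the same constant that appears as `%[[MIN_ZA_VECTORS]]` in the `vector_load_i16` test.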


================
Comment at: mlir/lib/Dialect/ArmSME/Utils/Utils.cpp:27
+
+bool mlir::arm_sme::isValidTileElementType(Type type) {
+  // TODO: add support for i128.
----------------
[nit] Naming is hard


================
Comment at: mlir/lib/Dialect/ArmSME/Utils/Utils.cpp:34
+
+bool mlir::arm_sme::isSMETileLikeVectorType(VectorType vType) {
+  if ((vType.getRank() != 2) && vType.allDimsScalable())
----------------
[nit] Naming is hard


================
Comment at: mlir/test/Dialect/ArmSME/roundtrip.mlir:197
 
+func.func @arm_sme_tile_load_i8(%memref : memref<?x?xi8>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
----------------
For consistency with the `tile_store` at the bottom.


================
Comment at: mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir:30-31
+// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: return
 func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
----------------
[nit] Not needed


================
Comment at: mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir:74
+
+// CHECK-LABEL: @vector_load_i8_rank_1_memref(
+// CHECK-SAME:                                %[[ARG0:.*]]: memref<?xi8>)
----------------
[nit] Just to make it clearer what's distinct about this test


================
Comment at: mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir:110-134
+// CHECK-LABEL: @vector_load_i16(
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi16>)
+// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi16> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 8 : index
----------------
This and the following tests differ only in a few details that are tricky to spot. Perhaps we should trim these to highlight the differences? That would be more in line with https://mlir.llvm.org/getting_started/TestingGuide/:
>     Tests should be minimal, and only check what is absolutely necessary.
>
> This means that anything in the output that is not core to the functionality that you are testing should not be present in a CHECK line. 

The 2 tests above are sufficient to test the other nuances.


================
Comment at: mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir:77
+
+  // Verify "mem1" == "mem2"
+  %init_1 = arith.constant 1 : i64
----------------
A "negative" test would be good too (i.e. verify that `mem1 != mem2`).
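The "Verify mem1 == mem2" step appears to reduce the comparison to a single `i64` that is printed and checked to be `1`. Assuming it works by multiplying per-element equality flags, starting from 1 as `%init_1` suggests, the positive and negative cases can be modelled in C++ as below. `allEqualByMulReduce` is a hypothetical helper, not code from the test.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of the integration test's check: map each element pair
// to 1 (equal) or 0 (different) and multiply the flags together, starting
// from 1. The result is 1 only if every pair matches.
template <typename T>
std::int64_t allEqualByMulReduce(const std::vector<T> &mem1,
                                 const std::vector<T> &mem2) {
  assert(mem1.size() == mem2.size() && "buffers must have the same size");
  std::int64_t acc = 1;
  for (std::size_t i = 0; i < mem1.size(); ++i)
    acc *= (mem1[i] == mem2[i]) ? 1 : 0;
  return acc;
}
```

A negative variant would fill `mem2` with different data and expect `0` instead of `1`.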


================
Comment at: mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir:96
+  // CHECK-ZA0_D-NEXT: 1
+  vector.print %mul_reduce : i64
+
----------------
How about printing `mem1` and `mem2` and checking something like this:
```
  // CHECK: 0.1, 0.1 {{.*}}
  // CHECK: 1.1, 1.1 {{.*}}
  // CHECK: <some clever string printed after all of mem1 has been printed>
```



https://reviews.llvm.org/D155306


