[PATCH] D155306: [mlir][ArmSME] Add tile load op and extend tile store tile size support

Cullen Rhodes via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 19 09:42:18 PDT 2023


c-rhodes marked 12 inline comments as done.
c-rhodes added a comment.

> Overall looks good, thanks! We definitely need to take care of the loop materialisation after this change (i.e. move it higher up the compilation stack).

Thanks for the comments! I agree regarding loop materialization; I'm currently looking into that separately from this patch.



================
Comment at: mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h:24
+
+/// Utility to return bitwidth of type which should be an integer or float.
+unsigned getWidth(Type type);
----------------
awarzynski wrote:
> 
> 

I discovered `getIntOrFloatBitWidth` so I've removed this and replaced it with that.


================
Comment at: mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp:102-120
+    auto vType = loadOrStoreOp.getVectorType();
+    if ((vType.getRank() != 2) && vType.allDimsScalable())
+      return failure();
+
+    // TODO: add support for i128.
+    auto elemType = vType.getElementType();
+    if ((elemType != rewriter.getI8Type()) &&
----------------
awarzynski wrote:
> I'd move this to a utility function - I expect that we'll be needing this for other Ops as well.

Moved to `arm_sme::isSMETileLikeVectorType`.
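For reference, here is a minimal, MLIR-free sketch of the kind of check such a helper performs. It models a vector type with a hypothetical `VectorTypeModel` struct instead of `VectorType`, and assumes the rules are: rank 2, every dimension scalable, an 8/16/32/64-bit element type (i128 is still a TODO), and a square base shape of `128 / esize` elements per dimension. The shape rule is an assumption and is not taken from the hunk above. The sketch rejects a type if either of the first two conditions fails, i.e. the equivalent of `(rank != 2) || !allDimsScalable()`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a vector type (not MLIR's VectorType).
struct VectorTypeModel {
  std::vector<int64_t> shape;  // base (minimum) size of each dimension
  std::vector<bool> scalable;  // whether each dimension is scalable
  unsigned elemBitWidth;       // element size in bits
};

// Assumed rules: rank 2, all dims scalable, supported element width, and a
// square base shape of (128 / esize) x (128 / esize) elements.
bool isSMETileLikeVectorTypeModel(const VectorTypeModel &t) {
  // Reject if the rank is not 2 OR any dimension is fixed-size.
  if (t.shape.size() != 2 || t.scalable.size() != 2)
    return false;
  for (bool s : t.scalable)
    if (!s)
      return false;
  switch (t.elemBitWidth) {
  case 8:
  case 16:
  case 32:
  case 64:
    break;
  default: // e.g. i128, still unsupported
    return false;
  }
  const int64_t minElts = 128 / t.elemBitWidth;
  return t.shape[0] == minElts && t.shape[1] == minElts;
}
```

For example, `vector<[16]x[16]xi8>` and `vector<[4]x[4]xi32>` pass, while `vector<[16]x16xi8>` (one fixed dimension) does not.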


================
Comment at: mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp:133-134
+  patterns.add<TransferWriteToArmSMELowering,
+               VectorLoadStoreToArmSMELowering<vector::LoadOp>,
+               VectorLoadStoreToArmSMELowering<vector::StoreOp>>(&ctx);
 }
----------------
awarzynski wrote:
> IMHO, this would be less noisy (i.e. no `VectorLoadStoreToArmSMELowering` template):
> ```
> patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSME, VectorStoreToArmSME>(&ctx);
> ```
> This way every element in `patterns.add` would have a more distinct name. But ultimately, it's a matter of preference so go with whatever you prefer.

That's a good suggestion! Done


================
Comment at: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp:97
 
-/// Lower 'arm_sme.store_tile' to a loop over the rows of ZA and store each row
-/// using 'arm_sme.intr.str'.
-///
-///  BEFORE:
-///  ```mlir
-///     arm_sme.tile_store %arg0[%c0, %c0], %0 : memref<?x?xi8>,
-///     vector<[16]x[16]xi8
-///  ```
-///
-///  AFTER:
-///  ```mlir
-///      %vscale = "llvm.intr.vscale"() : () -> index
-///      %c0 = arith.constant 0 : index
-///      %c1 = arith.constant 1 : index
-///      %c16 = arith.constant 16 : index
-///      %vec_size = arith.muli %c16, %vscale : index
-///      scf.for %row_idx = %c0 to %vec_size step %c1 {
-///        // (...)
-///        "arm_sme.intr.str"(%row_idx, %addr) : (i32, !llvm.ptr) -> ()
-///  ```
-struct TileStoreOpConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
-  using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
+/// Casts scalar integer tile id to i32 to feed into the SME intrinsics.
+Value castTileIDToI32(Value tile, Location loc,
----------------
awarzynski wrote:
> No casting happens in this routine :)




================
Comment at: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp:98-117
-/// Lower 'arm_sme.store_tile' to a loop over the rows of ZA and store each row
-/// using 'arm_sme.intr.str'.
-///
-///  BEFORE:
-///  ```mlir
-///     arm_sme.tile_store %arg0[%c0, %c0], %0 : memref<?x?xi8>,
-///     vector<[16]x[16]xi8
----------------
awarzynski wrote:
> Please preserve this comment. Fine details, imho, can be extracted from the code. But documenting the overall structure is helpful.




================
Comment at: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp:108-109
+
+/// Returns `offset` if memref is rank 2, otherwise adjusts `offset` by SVL<t>
+/// bytes.
+Value getOffset(MemRefType memRefType, Value offset, Value vscale,
----------------
awarzynski wrote:
> What "offset" is it? Why do we need to adjust it in the 1-D case and just return as is in 2-D?
> 
> In the 1-D case, is it:
> ```
> offset * vscale * minElems
> ```
> ? And is `minElems` the minimum number of elements in a scalable vector? So basically the "base size of an SVE vector"?
> What "offset" is it? Why do we need to adjust it in the 1-D case and just return as is in 2-D?

It's the offset applied to the load or store pointer. In the 2D case `getStridedElementPtr` does the arithmetic for us, but in the 1D case we have to do it ourselves.

> In the 1-D case, is it:
> ```
> offset * vscale * minElems
> ```

Yeah, and `offset` is `vnum`, so it's `vnum * vscale * minElems`. That is, the offset is the number of elements of the given type in a vector of SVL bits (SVLt), scaled by `vnum`, so the base pointer advances one tile vector at a time.
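A minimal sketch of that offset adjustment, written as plain C++ over integers rather than as MLIR ops. `getOffsetModel` and `minNumElts` are hypothetical names; the rank-2 branch simply passes the offset through because `getStridedElementPtr` handles the address arithmetic in that case.

```cpp
#include <cassert>
#include <cstdint>

// Minimum number of elements of the given width in 128 bits, i.e. the
// element count of an SVL-bit vector when vscale == 1.
int64_t minNumElts(unsigned elemBitWidth) { return 128 / elemBitWidth; }

// Hypothetical model of the adjustment:
//  - rank-2 memref: return `offset` unchanged;
//  - rank-1 memref: `offset` is `vnum`, scaled to an element offset of
//    vnum * vscale * minElems (i.e. `vnum` whole SVL-bit vectors).
int64_t getOffsetModel(unsigned memrefRank, int64_t offset, int64_t vscale,
                       unsigned elemBitWidth) {
  if (memrefRank == 2)
    return offset;
  return offset * vscale * minNumElts(elemBitWidth);
}
```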

> ? And is `minElems` the minimum number of elements in a scalable vector? So basically the "base size of an SVE vector"?

Yeah, so `128 / esize` (with `esize` the element size in bits). Perhaps `getMinNumElts` would actually be better implemented like that rather than with a switch.
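To illustrate the `128 / esize` idea, the sketch below compares a switch-based helper with the arithmetic one. Both function names are hypothetical, and the switch is an assumed shape of the current code, not a copy of it.

```cpp
#include <cassert>
#include <cstdint>

// Switch-based formulation (assumed shape of the existing helper).
int64_t minNumEltsSwitch(unsigned elemBitWidth) {
  switch (elemBitWidth) {
  case 8:
    return 16;
  case 16:
    return 8;
  case 32:
    return 4;
  case 64:
    return 2;
  default:
    return -1; // unsupported element width
  }
}

// Arithmetic formulation: 128 / esize, with esize in bits.
int64_t minNumEltsDiv(unsigned elemBitWidth) { return 128 / elemBitWidth; }
```

One trade-off: the switch rejects unsupported widths by construction, whereas the division silently returns a value for any width (e.g. 5 for 24 bits), so it would need to be paired with a separate element-type check.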

As for where support for both 1D and 2D memrefs came from, I initially started with these integration tests that dump ZA but could get it to work with 2D memrefs: https://gist.github.com/c-rhodes/1e9f2d8fd0ca3c6539f167e08079f6ab

I found those tests useful for verification, but since the output varies depending on the runtime VL we can't add these tests.



================
Comment at: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp:163-164
+                                           {offset}, rewriter);
+    auto vnumI32 =
+        rewriter.create<arith::IndexCastUIOp>(loc, rewriter.getI32Type(), vnum);
+    auto one = rewriter.create<arith::ConstantOp>(
----------------
awarzynski wrote:
> The "horizontal" load instructions take "slice number": https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/LD1H--scalar-plus-scalar--tile-slice---Contiguous-load-of-halfwords-to-16-bit-element-ZA-tile-slice-?lang=en.
> 
> I would rename `vnumI32` as `sliceNumI32` so that this is easier to match with the spec.

Good point; naming is difficult. I've updated it to "tile slice", which seems consistent with LLVM as well.
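To make the slice terminology concrete, here is a plain C++ simulation of filling a tile one horizontal slice at a time, where slice number `sliceIdx` is read from `base + sliceIdx * stride`. It is a hypothetical model (32-bit elements, a row-major buffer, a given vscale), not the lowering in this patch, and it uses no SME intrinsics.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical tile representation: tile[sliceIdx][col].
using Tile = std::vector<std::vector<int32_t>>;

// Loads a square tile of 32-bit elements slice by slice from a row-major
// buffer whose rows are `stride` elements apart.
Tile loadTileBySlices(const std::vector<int32_t> &memory, int64_t stride,
                      int64_t vscale) {
  // 128 / 32 = 4 elements per slice at vscale == 1; the tile is square, so
  // this is also the number of slices.
  const int64_t numSlices = 4 * vscale;
  Tile tile(numSlices, std::vector<int32_t>(numSlices));
  for (int64_t sliceIdx = 0; sliceIdx < numSlices; ++sliceIdx) {
    // Address of the first element of this tile slice.
    const int64_t sliceBase = sliceIdx * stride;
    for (int64_t col = 0; col < numSlices; ++col)
      tile[sliceIdx][col] = memory[sliceBase + col];
  }
  return tile;
}
```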


================
Comment at: mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt:1
+add_mlir_dialect_library(MLIRArmSMEUtils
+  Utils.cpp
----------------
awarzynski wrote:
> Do we need a dedicated library for one CPP file? Perhaps it's sufficient to add this to `MLIRArmSMETransforms`?

I copied this from another dialect and it seems all dialects with utils do this.


================
Comment at: mlir/lib/Dialect/ArmSME/Utils/Utils.cpp:31
+
+/// Utility to return minimum number of elements for the given element type in
+/// an SME array vector.
----------------
awarzynski wrote:
> Given that this code will only be used by SME/SSVE, why not name this as:
> ```
> getSVEVectorBaseSize
> ```
> 
> This way it will be very clear that it's some special SME/SSVE hook. Also:
> 
> > in SME array vector
> 
> + "and an SVE vector"?

Naming this is tricky. I mean it to be the minSVLT, i.e. the minimum number of elements in a vector of SVL bits. When scaled by vscale, it gives both the number of tile slices (vectors of SVL bits) in a ZA tile and the number of elements in a tile slice. I've updated it to `getSMETileSliceMinNumElts` but will have a think. Also, I'm not sure we want to mention SVE here?
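As a worked example of the quantity described above, the sketch computes the tile geometry for a given element width and vscale, with SVL = vscale * 128 bits. `TileGeometry` and `getTileGeometry` are hypothetical names. The `numTiles` field uses the architectural fact that ZA (an SVL/8 x SVL/8 byte array) is divided into esize / 8 tiles for elements of esize bits, so the slices of all tiles together cover the SVL/8 rows of ZA.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical description of one ZA tile for a given element width.
struct TileGeometry {
  int64_t minSVLt;      // elements per SVL-bit vector when vscale == 1
  int64_t numSlices;    // tile slices (rows) in one tile
  int64_t eltsPerSlice; // elements in one tile slice
  int64_t numTiles;     // tiles of this element size held by ZA
};

TileGeometry getTileGeometry(unsigned elemBitWidth, int64_t vscale) {
  TileGeometry g;
  g.minSVLt = 128 / elemBitWidth;
  // Scaling minSVLt by vscale gives both dimensions of the square tile.
  g.numSlices = g.minSVLt * vscale;
  g.eltsPerSlice = g.minSVLt * vscale;
  // ZA is split into esize / 8 tiles (1 for 8-bit, ..., 8 for 64-bit).
  g.numTiles = elemBitWidth / 8;
  return g;
}
```

For example, with 32-bit elements and vscale = 4 (SVL = 512 bits), each tile is 16 x 16 elements and ZA holds 4 such tiles, i.e. 64 rows in total.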


CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D155306/new/

https://reviews.llvm.org/D155306


