[PATCH] D155306: [mlir][ArmSME] Add tile load op and extend tile store tile size support

Andrzej Warzynski via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 19 00:47:39 PDT 2023


awarzynski added a comment.

Overall looks good, thanks! We definitely need to take care of the loop materialisation after this change (i.e. move it higher up the compilation stack).

I've left a fair few comments, but nothing major, and this is a rather large patch. I still need to go over the tests.

Thanks for working on this!



================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:229
+  let description = [{
+    Load a 2D SME "virtual tile" to memory.
+
----------------



================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:237
+  }];
+  let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
+                       Variadic<Index>:$indices);
----------------



================
Comment at: mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h:24
+
+/// Utility to return bitwidth of type which should be an integer or float.
+unsigned getWidth(Type type);
----------------



================
Comment at: mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h:27
+
+/// Utility to return minimum number of elements for the given element type in
+/// an SME array vector.
----------------



================
Comment at: mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt:14
   MLIRLLVMCommonConversion
+  MLIRArmSMEUtils
   )
----------------
Please keep this list sorted alphabetically.


================
Comment at: mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp:102-120
+    auto vType = loadOrStoreOp.getVectorType();
+    if ((vType.getRank() != 2) && vType.allDimsScalable())
+      return failure();
+
+    // TODO: add support for i128.
+    auto elemType = vType.getElementType();
+    if ((elemType != rewriter.getI8Type()) &&
----------------
I'd move this to a utility function - I expect that we'll be needing this for other Ops as well.
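Below is a rough sketch of what such a shared helper could look like. It is written against a simplified stand-in for `VectorType` so that it compiles without MLIR. `VectorTypeInfo`, `isValidSMETileVectorType` and `isValidSMETileElementWidth` are hypothetical names. Treating 8/16/32/64-bit elements as supported (with i128 left out, as per the TODO) is an assumption. A valid tile type has to be 2-D *and* scalable in every dimension, so rejecting a type means negating both conditions and combining them with `||`.

```cpp
#include <cassert>

// Simplified stand-in for mlir::VectorType (hypothetical, illustration only).
struct VectorTypeInfo {
  int rank;              // number of dimensions
  bool allDimsScalable;  // true for e.g. vector<[16]x[16]xi8>
  unsigned elemBitWidth; // bit width of the element type
};

// Hypothetical: true if the element width maps onto an SME tile.
// i128 is deliberately excluded (left as a TODO in the patch).
bool isValidSMETileElementWidth(unsigned bits) {
  switch (bits) {
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

// Hypothetical helper shared by the load/store (and future) lowerings.
bool isValidSMETileVectorType(const VectorTypeInfo &vType) {
  // A tile must be 2-D *and* scalable in every dimension.
  if (vType.rank != 2 || !vType.allDimsScalable)
    return false;
  return isValidSMETileElementWidth(vType.elemBitWidth);
}
```

Each lowering would then start with a single `if (!isValidSMETileVectorType(...)) return failure();`.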


================
Comment at: mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp:133-134
+  patterns.add<TransferWriteToArmSMELowering,
+               VectorLoadStoreToArmSMELowering<vector::LoadOp>,
+               VectorLoadStoreToArmSMELowering<vector::StoreOp>>(&ctx);
 }
----------------
IMHO, this would be less noisy (i.e. no `VectorLoadStoreToArmSMELowering` template):
```
patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSME, VectorStoreToArmSME>(&ctx);
```
This way every element in `patterns.add` would have a more distinct name. But ultimately, it's a matter of preference so go with whatever you prefer.
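If you prefer to keep the shared template, type aliases are one way to still give each entry in `patterns.add` a distinct name. The sketch below uses hypothetical stand-ins (`LoadOp`, `StoreOp`, `TransferWriteOp`, `PatternList` and the `rootOpName` hook) rather than MLIR's `RewritePatternSet`, purely to show the shape of the code. It needs C++17 for the fold expression.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Stand-ins for vector::LoadOp / vector::StoreOp / vector::TransferWriteOp.
struct LoadOp { static constexpr const char *name = "vector.load"; };
struct StoreOp { static constexpr const char *name = "vector.store"; };
struct TransferWriteOp {
  static constexpr const char *name = "vector.transfer_write";
};

// The shared implementation stays a single template...
template <typename OpTy> struct VectorLoadStoreToArmSMELowering {
  static std::string rootOpName() { return OpTy::name; }
};
struct TransferWriteToArmSMELowering {
  static std::string rootOpName() { return TransferWriteOp::name; }
};

// ...while aliases give each instantiation a distinct, readable name.
using VectorLoadToArmSME = VectorLoadStoreToArmSMELowering<LoadOp>;
using VectorStoreToArmSME = VectorLoadStoreToArmSMELowering<StoreOp>;

// Minimal stand-in for a pattern set: records the root op of each pattern.
struct PatternList {
  std::vector<std::string> rootOps;
  template <typename... Patterns> void add() {
    (rootOps.push_back(Patterns::rootOpName()), ...);
  }
};
```

With the aliases, the registration reads `patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSME, VectorStoreToArmSME>(...)` while the implementation remains shared.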


================
Comment at: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp:97
 
-/// Lower 'arm_sme.store_tile' to a loop over the rows of ZA and store each row
-/// using 'arm_sme.intr.str'.
-///
-///  BEFORE:
-///  ```mlir
-///     arm_sme.tile_store %arg0[%c0, %c0], %0 : memref<?x?xi8>,
-///     vector<[16]x[16]xi8
-///  ```
-///
-///  AFTER:
-///  ```mlir
-///      %vscale = "llvm.intr.vscale"() : () -> index
-///      %c0 = arith.constant 0 : index
-///      %c1 = arith.constant 1 : index
-///      %c16 = arith.constant 16 : index
-///      %vec_size = arith.muli %c16, %vscale : index
-///      scf.for %row_idx = %c0 to %vec_size step %c1 {
-///        // (...)
-///        "arm_sme.intr.str"(%row_idx, %addr) : (i32, !llvm.ptr) -> ()
-///  ```
-struct TileStoreOpConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
-  using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
+/// Casts scalar integer tile id to i32 to feed into the SME intrinsics.
+Value castTileIDToI32(Value tile, Location loc,
----------------
No casting happens in this routine :)


================
Comment at: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp:98-117
-/// Lower 'arm_sme.store_tile' to a loop over the rows of ZA and store each row
-/// using 'arm_sme.intr.str'.
-///
-///  BEFORE:
-///  ```mlir
-///     arm_sme.tile_store %arg0[%c0, %c0], %0 : memref<?x?xi8>,
-///     vector<[16]x[16]xi8
----------------
Please preserve this comment. Fine details can, IMHO, be inferred from the code, but documenting the overall structure is helpful.


================
Comment at: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp:102
+    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
+  else if (arm_sme::getWidth(tile.getType()) > 32)
+    return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
----------------
[[ https://llvm.org/docs/CodingStandards.html#don-t-use-else-after-a-return | don't use else after a return ]] ;-)

Similar comment for `getOffset`.
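To illustrate the shape being asked for, here is a model of just the width dispatch in `castTileIDToI32`. The enum `TileIDFixup` and the function `tileIDFixupForWidth` are hypothetical stand-ins for creating `arith::ExtUIOp` / `arith::TruncIOp`. The condition on the first branch (`width < 32`) is assumed, since it is not visible in the quoted hunk.

```cpp
#include <cassert>

// Hypothetical stand-in for the op each branch would create.
enum class TileIDFixup { ZeroExtend, Truncate, None };

// Same control flow as castTileIDToI32, without `else` after `return`.
TileIDFixup tileIDFixupForWidth(unsigned width) {
  assert(width > 0 && "tile ID must have a non-zero bit width");
  if (width < 32)
    return TileIDFixup::ZeroExtend; // e.g. arith::ExtUIOp to i32
  if (width > 32)
    return TileIDFixup::Truncate; // e.g. arith::TruncIOp to i32
  return TileIDFixup::None; // already i32, use as is
}
```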


================
Comment at: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp:108-109
+
+/// Returns `offset` if memref is rank 2, otherwise adjusts `offset` by SVL<t>
+/// bytes.
+Value getOffset(MemRefType memRefType, Value offset, Value vscale,
----------------
Which "offset" is this? Why do we need to adjust it in the 1-D case and just return it as is in the 2-D case?

In the 1-D case, is it:
```
offset * vscale * minElems
```
? And is `minElems` the minimum number of elements in a scalable vector? So basically the "base size of an SVE vector"?
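To make the question concrete, here is the arithmetic under the reading proposed above. This is an assumption about the patch's intent, not something the patch confirms, and `getOffsetModel` is a hypothetical name. For example, with `i16` elements (`minElems` = 128 / 16 = 8) and `vscale` = 4, row 3 of a 1-D buffer would start at element 3 * 4 * 8 = 96.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of getOffset under the reading above: for a 2-D memref
// the offset is a row index and is returned unchanged; for a 1-D memref it
// is scaled by the runtime row length in elements (vscale * minElems).
int64_t getOffsetModel(unsigned memRefRank, int64_t offset, int64_t vscale,
                       int64_t minElems) {
  assert((memRefRank == 1 || memRefRank == 2) && "expected 1-D or 2-D memref");
  if (memRefRank == 2)
    return offset;
  // 1-D: skip `offset` whole rows of vscale * minElems elements each.
  return offset * vscale * minElems;
}
```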


================
Comment at: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp:163-164
+                                           {offset}, rewriter);
+    auto vnumI32 =
+        rewriter.create<arith::IndexCastUIOp>(loc, rewriter.getI32Type(), vnum);
+    auto one = rewriter.create<arith::ConstantOp>(
----------------
The "horizontal" load instructions take a "slice number": https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/LD1H--scalar-plus-scalar--tile-slice---Contiguous-load-of-halfwords-to-16-bit-element-ZA-tile-slice-?lang=en.

I would rename `vnumI32` as `sliceNumI32` so that this is easier to match with the spec.
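The following toy model shows why "slice number" is the natural name. Each horizontal load copies one contiguous row of memory into one horizontal slice of the tile, and the slice index selects which row of the tile is written. The predicate is treated as all-active and the tile ID is ignored. `Tile`, `loadHorizontalSlice` and `loadTile` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy model of a ZA tile: horizontal slices of 16-bit elements.
using Tile = std::vector<std::vector<uint16_t>>;

// Hypothetical model of one horizontal load (all-active predicate): copy one
// slice's worth of contiguous halfwords from `mem` at `ptr` into `sliceNum`.
void loadHorizontalSlice(Tile &tile, const std::vector<uint16_t> &mem,
                         std::size_t ptr, std::size_t sliceNum) {
  assert(sliceNum < tile.size() && "slice number out of range");
  for (std::size_t j = 0; j < tile[sliceNum].size(); ++j)
    tile[sliceNum][j] = mem[ptr + j];
}

// Loads the whole tile row by row; the loop counter is the slice number.
void loadTile(Tile &tile, const std::vector<uint16_t> &mem, std::size_t base,
              std::size_t stride) {
  for (std::size_t sliceNum = 0; sliceNum < tile.size(); ++sliceNum)
    loadHorizontalSlice(tile, mem, base + sliceNum * stride, sliceNum);
}
```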


================
Comment at: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp:173-186
+    if (width == 8)
+      rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, mask, ptr, tileCast,
+                                                       vnumI32);
+    else if (width == 16)
+      rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, mask, ptr, tileCast,
+                                                       vnumI32);
+    else if (width == 32)
----------------
Could this be a switch statement instead?
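Something along these lines, modelled with a hypothetical enum in place of the `rewriter.create<arm_sme::aarch64_sme_ld1*_horiz>(...)` calls. The `ld1b`/`ld1h` cases correspond to the quoted code. Mapping 32 and 64 bits to the `w`/`d` variants is an assumption about the truncated part of the hunk.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-ins for the horizontal load ops created in the pattern.
enum class HorizLoadIntr { Ld1bHoriz, Ld1hHoriz, Ld1wHoriz, Ld1dHoriz };

// Selects the horizontal load for an element bit width.
std::optional<HorizLoadIntr> selectHorizLoad(unsigned width) {
  switch (width) {
  case 8:
    return HorizLoadIntr::Ld1bHoriz;
  case 16:
    return HorizLoadIntr::Ld1hHoriz;
  case 32:
    return HorizLoadIntr::Ld1wHoriz;
  case 64:
    return HorizLoadIntr::Ld1dHoriz;
  default:
    return std::nullopt; // e.g. i128 is not handled yet
  }
}
```

If the element type has already been validated before this point, the `default` case could use `llvm_unreachable` instead of returning an empty optional.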


================
Comment at: mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt:1
+add_mlir_dialect_library(MLIRArmSMEUtils
+  Utils.cpp
----------------
Do we need a dedicated library for one CPP file? Perhaps it's sufficient to add this to `MLIRArmSMETransforms`?


================
Comment at: mlir/lib/Dialect/ArmSME/Utils/Utils.cpp:20
+
+/// Utility to return bitwidth of type which should be an integer or float.
+unsigned mlir::arm_sme::getWidth(Type type) {
----------------
This repeats the comment from the header file, so the two copies will drift out of sync if somebody (e.g. me) forgets about that and updates only one of them. IMHO, it's fine to limit the comments to where the interface is declared (i.e. the header file).


================
Comment at: mlir/lib/Dialect/ArmSME/Utils/Utils.cpp:23
+  if (auto integerType = dyn_cast<IntegerType>(type))
+    return integerType.getWidth();
+  else if (auto floatType = dyn_cast<FloatType>(type))
----------------
Avoid `else` after `return`.


================
Comment at: mlir/lib/Dialect/ArmSME/Utils/Utils.cpp:31
+
+/// Utility to return minimum number of elements for the given element type in
+/// an SME array vector.
----------------
Given that this code will only be used by SME/SSVE, why not name this as:
```
getSVEVectorBaseSize
```

This way it will be very clear that it's some special SME/SSVE hook. Also:

> in SME array vector

Perhaps add "and an SVE vector"?
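For reference, the value in question is the number of elements that fit in 128 bits, which is the minimum (and granule) size of an SVE/streaming SVE vector. The runtime size is `vscale` times that. Below is a minimal sketch of the renamed helper. It takes the bit width (as returned by `getWidth`) rather than an `mlir::Type` so that it stands alone; that signature is an assumption.

```cpp
#include <cassert>

// Hypothetical sketch of the suggested helper: number of elements of the
// given bit width in the minimum (128-bit) SVE/SSVE vector. The runtime
// element count is this value multiplied by vscale.
unsigned getSVEVectorBaseSize(unsigned elemBitWidth) {
  assert((elemBitWidth == 8 || elemBitWidth == 16 || elemBitWidth == 32 ||
          elemBitWidth == 64) &&
         "unsupported element width");
  constexpr unsigned kMinVectorLengthInBits = 128;
  return kMinVectorLengthInBits / elemBitWidth;
}
```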


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D155306
