[PATCH] D154941: [mlir][ArmSME] Add custom get_tile_id and cast ops

Andrzej Warzynski via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 12 01:51:55 PDT 2023


awarzynski added a comment.

Great work @c-rhodes, thank you! I rebased https://reviews.llvm.org/D154867 on top of this change straight away, and that immediately solved the "data flow" issue 🙏🏻

Overall this looks solid to me. I've left a few minor suggestions - mostly to clarify the documentation.



================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:46-47
+      And<[IsVectorOfRankPred<[2]>,
+           CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>,
+           CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})">]>,
+  description>;
----------------
[nit] Perhaps extract the preconditions into a dedicated definition?

Also, when would this be triggered:
```
 CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})"
```
It feels like a "VectorType verifier" that could be safely skipped (i.e. there is nothing SME-specific about it).
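For illustration only, here is a minimal standalone C++ model of the combined predicate. It does not use MLIR. `FakeVectorType` and `matchesSMETileType` are hypothetical stand-ins for `::mlir::VectorType` and the generated constraint. The model assumes a vector type can be reduced to a shape plus one "scalable" flag per dimension, and it separates the three conjuncts so you can see which one rejects which input.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for ::mlir::VectorType: a shape plus one
// "scalable" flag per dimension, e.g. vector<[4]x[4]xi32> -> {4, 4}, {true, true}.
struct FakeVectorType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;

  bool allDimsScalable() const {
    for (bool s : scalableDims)
      if (!s)
        return false;
    return true;
  }
};

// Models the three conjuncts of the constraint in ArmSME.td:
//   1. IsVectorOfRankPred<[2]>  -> rank must be 2
//   2. allDimsScalable()        -> every dimension must be scalable
//   3. getShape() == {dims...}  -> exact shape for the element type (the CPred in question)
bool matchesSMETileType(const FakeVectorType &t,
                        const std::vector<int64_t> &expectedDims) {
  if (t.shape.size() != 2)
    return false; // rank check
  if (!t.allDimsScalable())
    return false; // scalability check
  return t.shape == expectedDims; // shape check
}
```

In this model, only the shape conjunct distinguishes e.g. `vector<[8]x[8]xi32>` from `vector<[4]x[4]xi32>`. The rank and scalability checks accept both.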


================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:75-76
+    A `cast_tile_to_vector` operation does a cast from a tile id to a 2-d
+    scalable vector type, which represents an SME "virtual tile". This is used
+    in conjunction with `cast_vector_to_tile` to preserve dataflow and type
+    legality when lowering vector ops that have both inputs and outputs, to SME
----------------
Perhaps:
* "This is used in conjunction with `cast_vector_to_tile`" -->  "This would normally be used in conjunction with "virtual tile load" operations to model the output of such Ops. This is required to preserve data-flow, as SME intrinsics do not return values."

Basically, this Op and `CastVectorToTile` complement each other, right? And I guess that's what we want to say here? But IMHO, this description should focus on `CastTileToVector`. 


================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:81-96
+    ```mlir
+
+    // input
+    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+    // lower vector.load -> SME intrinsics
----------------
This example is a bit busy. I would focus on the Op that's defined here (i.e. `CastTileToVector`), so that this description is self-contained (try to avoid references to `CastVectorToTile`). My suggestion:

  EXAMPLE:
  Input:
  ```mlir
      %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
      vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
  ```

  After lowering `vector.load`:
  ```mlir
      %tile_id = arith.constant 0 : i32
      scf.for %vnum = %c0 to %num_vectors step %c1 {
        // ...
        "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
      }
      %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
      vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
  ```

Another question - are `vector.load` and `vector.store` the right Ops here? We don't really lower from these ATM.


================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:105-108
+    The opposite is true for the `vector.store`, when lowered to intrinsics
+    they would be preceded by a `cast_vector_to_tile` op. Once the lowering is
+    complete the canonicalizer will fold the casts away. The
+    `cast_vector_to_tile` op example shows the other half of the lowering.
----------------
This paragraph describes `CastVectorToTile` rather than the Op defined here.


================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:110-113
+    These casts are expected to be folded, but may persist if there's an
+    incomplete lowering where a vector op has been lowered to SME but the uses
+    haven't, much like if `-reconcile-unrealized-casts` fails. Currently these
+    cast ops cannot be lowered to LLVM, but may be in the future.
----------------
This comment refers to "these casts", but this is just one cast ;-)


================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:130-161
+    Example:
+    ```mlir
+
+    // input
+    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
----------------
This example is a bit busy. I would focus on the Op that's defined here (i.e. `CastVectorToTile`), so that this description is self-contained (try to avoid references to `CastTileToVector`). My suggestion:


    EXAMPLE:
    Input:
    ```mlir

    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
    ```
 
    Output after lowering `vector.store`:
    ```mlir
    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
    
    %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
    scf.for %vnum = %c0 to %num_vectors step %c1 {
      // ...
      "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
    }
    ```

    Additionally, canonicalization will look through `cast_vector_to_tile` Ops and fold the
    casts away if the input comes from a `cast_tile_to_vector` Op.
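For readers less familiar with MLIR folding, here is a minimal standalone C++ sketch of the fold described above. It runs on a toy IR rather than MLIR. `OpKind`, `ToyOp` and `foldCastVectorToTile` are hypothetical names. The sketch mirrors the `canonicalize.mlir` test in this patch: a tile ID round-tripped through `cast_tile_to_vector` and then `cast_vector_to_tile` is replaced by the original ID. In MLIR this would presumably live in the op's `fold` hook, but that is an assumption about the patch.

```cpp
#include <memory>
#include <string>

enum class OpKind { TileIdConstant, CastTileToVector, CastVectorToTile };

// Each toy op produces exactly one result, so a pointer to the op also
// stands in for its result value.
struct ToyOp {
  OpKind kind;
  std::shared_ptr<ToyOp> operand; // nullptr for TileIdConstant
  std::string name;               // SSA-style name, for readability only
};

using ToyValue = std::shared_ptr<ToyOp>;

// cast_vector_to_tile(cast_tile_to_vector(%tile_id)) -> %tile_id.
// Returns the replacement value, or nullptr if the op cannot be folded.
ToyValue foldCastVectorToTile(const ToyValue &op) {
  if (!op || op->kind != OpKind::CastVectorToTile)
    return nullptr;
  const ToyValue &src = op->operand;
  if (src && src->kind == OpKind::CastTileToVector)
    return src->operand; // look through the pair of casts
  return nullptr;
}
```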



================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:178
+    ```mlir
+    // Allocate an 8-bit element ZA tile
+    %za0_b = arm_sme.get_tile_id : i8
----------------
[nit] Is any "tile allocation" really taking place? Perhaps "Allocate and return a 'virtual tile' ID"?


================
Comment at: mlir/test/Dialect/ArmSME/canonicalize.mlir:10-11
+  // CHECK-NOT: arm_sme.cast_vector_to_tile
+  %tile = arm_sme.cast_tile_to_vector %tile_id_0 : i8 to vector<[16]x[16]xi8>
+  %tile_id_1 = arm_sme.cast_vector_to_tile %tile : vector<[16]x[16]xi8> to i8
+  // CHECK-NEXT: return %[[TILE_ID]] : i8
----------------
What about "the other way round"?
```
%tile_id = arm_sme.cast_vector_to_tile %tile_1 : vector<[16]x[16]xi8> to i8  
%tile_2 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
  
```


================
Comment at: mlir/test/Dialect/ArmSME/invalid.mlir:5
+
+func.func @arm_sme_cast_tile_to_vector__bad_vector_type(%tile_id : i8) -> vector<[16]xi8> {
+  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
----------------
How about:
```
func.func @arm_sme_cast_tile_to_vector__bad_vector_type(%tile_id : i8) -> vector<[16]x16xi8> 
```
and other combinations? For example:
* `vector<[16]x[16]xi4>`
* `vector<16x[16]xi8>`
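The standalone C++ model below (not MLIR) makes the suggested combinations concrete. It is built only from the list of valid tile types in the expected-error message quoted above. `TileTypeDesc` and `isValidSMETileType` are hypothetical, and element types are plain strings for simplicity. Each suggested negative case fails a different part of the check: `vector<[16]x16xi8>` and `vector<16x[16]xi8>` fail scalability, and `vector<[16]x[16]xi4>` fails the element-type lookup.

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical description of a vector type: element type name,
// shape, and one scalable flag per dimension.
struct TileTypeDesc {
  std::string elemType;
  std::vector<int64_t> shape;
  std::vector<bool> scalable;
};

// Valid SME tile types, taken from the expected-error message in
// invalid.mlir: element type -> size of each (scalable) dimension.
const std::map<std::string, int64_t> kTileDimForElemType = {
    {"i8", 16}, {"i16", 8}, {"i32", 4}, {"i64", 2}, {"i128", 1},
    {"f16", 8}, {"bf16", 8}, {"f32", 4}, {"f64", 2}};

bool isValidSMETileType(const TileTypeDesc &t) {
  auto it = kTileDimForElemType.find(t.elemType);
  if (it == kTileDimForElemType.end())
    return false; // unsupported element type, e.g. i4
  const int64_t dim = it->second;
  // Must be exactly [dim]x[dim]: rank 2, both dimensions scalable.
  return t.shape == std::vector<int64_t>{dim, dim} &&
         t.scalable == std::vector<bool>{true, true};
}
```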


================
Comment at: mlir/test/Dialect/ArmSME/roundtrip.mlir:8
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
+  return %0 : vector<[16]x[16]xi8>
+}
----------------
Could you add one more element type? For example, `vector<[1]x[1]xi128>` (i.e. the other extreme).



https://reviews.llvm.org/D154941


