[PATCH] D154941: [mlir][ArmSME] Add custom get_tile_id and cast ops
Andrzej Warzynski via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 12 01:51:55 PDT 2023
awarzynski added a comment.
Great work @c-rhodes, thank you! I've actually rebased https://reviews.llvm.org/D154867 on top of this change and it immediately solved the "data flow" issue 🙏🏻
Overall this looks solid to me. I've left a few minor suggestions - mostly to clarify the documentation.
================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:46-47
+ And<[IsVectorOfRankPred<[2]>,
+ CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>,
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})">]>,
+ description>;
----------------
[nit] Perhaps extract the preconditions to a dedicated definition?
Also, when would this be triggered:
```
CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})"
```
It feels like a "VectorType verifier" that could be safely skipped (i.e. nothing SME-specific).
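For illustration, extracting the shared preconditions could look something like this (a sketch only; the definition names are hypothetical, and the shape predicate is kept here even though it may be redundant, per the comment above):
```
// Hypothetical refactoring: pull the shared preconditions out into named
// definitions so each tile-type def stays readable.
def IsAllDimsScalablePred
    : CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>;

class SMETileVectorPrecondition<list<int> dims>
    : And<[IsVectorOfRankPred<[2]>,
           IsAllDimsScalablePred,
           CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == "
                 "ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})">]>;
```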
================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:75-76
+ A `cast_tile_to_vector` operation does a cast from a tile id to a 2-d
+ scalable vector type, which represents an SME "virtual tile". This is used
+ in conjunction with `cast_vector_to_tile` to preserve dataflow and type
+ legality when lowering vector ops that have both inputs and outputs, to SME
----------------
Perhaps:
* "This is used in conjunction with `cast_vector_to_tile`" --> "This would normally be used in conjunction with "virtual tile load" operations to model the output of such Ops. This is required to preserve data-flow as SME intrinsics do not return values."
Basically, this Op and `CastVectorToTile` complement each other, right? And I guess that's what we want to say here? But IMHO, this description should focus on `CastTileToVector`.
================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:81-96
+ ```mlir
+
+ // input
+ %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+ vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+ // lower vector.load -> SME intrinsics
----------------
This example is a bit busy. I would focus on the Op that's defined here (i.e. `CastTileToVector`), so that this description is self-contained (try to avoid references to `CastVectorToTile`). My suggestion:
EXAMPLE:
Input:
```mlir
%tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
```
After lowering `vector.load`:
```mlir
%tile_id = arith.constant 0 : i32
scf.for %vnum = %c0 to %num_vectors step %c1 {
// ...
"arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
%tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
```
Another question - are `vector.load` and `vector.store` the right Ops here? We don't really lower from these ATM.
================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:105-108
+ The opposite is true for the `vector.store`, when lowered to intrinsics
+ they would be preceded by a `cast_vector_to_tile` op. Once the lowering is
+ complete the canonicalizer will fold the casts away. The
+ `cast_vector_to_tile` op example shows the other half of the lowering.
----------------
This comment refers to `CastVectorToTile`
================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:110-113
+ These casts are expected to be folded, but may persist if there's an
+ incomplete lowering where a vector op has been lowered to SME but the uses
+ haven't, much like if `-reconcile-unrealized-casts` fails. Currently these
+ cast ops cannot be lowered to LLVM, but may be in the future.
----------------
This comment refers to "these casts", but this is just one cast ;-)
================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:130-161
+ Example:
+ ```mlir
+
+ // input
+ %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+ vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
----------------
This example is a bit busy. I would focus on the Op that's defined here (i.e. `CastVectorToTile`), so that this description is self-contained (try to avoid references to `CastTileToVector`). My suggestion:
EXAMPLE:
Input:
```mlir
%tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
```
Output after lowering `vector.store`:
```mlir
%tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
%tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
scf.for %vnum = %c0 to %num_vectors step %c1 {
// ...
"arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
```
Additionally, canonicalization will look through a `cast_vector_to_tile` Op and fold it away if its input comes from a `cast_tile_to_vector`.
================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:178
+ ```mlir
+ // Allocate an 8-bit element ZA tile
+ %za0_b = arm_sme.get_tile_id : i8
----------------
[nit] Is there any "tile allocation" really taking place? Perhaps "Allocate and return a "virtual tile" ID"?
================
Comment at: mlir/test/Dialect/ArmSME/canonicalize.mlir:10-11
+ // CHECK-NOT: arm_sme.cast_vector_to_tile
+ %tile = arm_sme.cast_tile_to_vector %tile_id_0 : i8 to vector<[16]x[16]xi8>
+ %tile_id_1 = arm_sme.cast_vector_to_tile %tile : vector<[16]x[16]xi8> to i8
+ // CHECK-NEXT: return %[[TILE_ID]] : i8
----------------
What about "the other way round"?
```
%tile_id = arm_sme.cast_vector_to_tile %tile_1 : vector<[16]x[16]xi8> to i8
%tile_2 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
```
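i.e. a test along these lines (a sketch only, assuming the fold also applies in this direction; the function name and CHECK lines are hypothetical and follow the style of the existing test):
```mlir
// Hypothetical test for folding the casts composed in the opposite order:
// CHECK-LABEL: @cast_vector_to_tile__cast_tile_to_vector
// CHECK-NOT: arm_sme.cast_vector_to_tile
// CHECK-NOT: arm_sme.cast_tile_to_vector
func.func @cast_vector_to_tile__cast_tile_to_vector(%tile_1 : vector<[16]x[16]xi8>) -> vector<[16]x[16]xi8> {
  %tile_id = arm_sme.cast_vector_to_tile %tile_1 : vector<[16]x[16]xi8> to i8
  %tile_2 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
  return %tile_2 : vector<[16]x[16]xi8>
}
```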
================
Comment at: mlir/test/Dialect/ArmSME/invalid.mlir:5
+
+func.func @arm_sme_cast_tile_to_vector__bad_vector_type(%tile_id : i8) -> vector<[16]xi8> {
+ // expected-error @+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
----------------
How about:
```
func.func @arm_sme_cast_tile_to_vector__bad_vector_type(%tile_id : i8) -> vector<[16]x16xi8>
```
and other combinations? For example:
* `vector<[16]x[16]xi4>`
* `vector<16x[16]xi8>`
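Sketches of what those extra negative cases might look like (assuming the verifier emits the same "op result #0 must be ..." diagnostic; the full expected text is elided here):
```mlir
// Hypothetical extra negative tests; exact diagnostic text elided.
func.func @arm_sme_cast_tile_to_vector__bad_element_type(%tile_id : i8) -> vector<[16]x[16]xi4> {
  // expected-error @+1 {{op result #0 must be ...}}
  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
  return %0 : vector<[16]x[16]xi4>
}

func.func @arm_sme_cast_tile_to_vector__bad_scalability(%tile_id : i8) -> vector<16x[16]xi8> {
  // expected-error @+1 {{op result #0 must be ...}}
  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
  return %0 : vector<16x[16]xi8>
}
```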
================
Comment at: mlir/test/Dialect/ArmSME/roundtrip.mlir:8
+ %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
+ return %0 : vector<[16]x[16]xi8>
+}
----------------
Could you add one more element type? For example, `vector<[1]x[1]xi128>` (i.e. the other extreme).
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D154941/new/
https://reviews.llvm.org/D154941