[PATCH] D154941: [mlir][ArmSME] Add custom get_tile_id and cast ops
Cullen Rhodes via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 12 08:49:16 PDT 2023
c-rhodes marked 10 inline comments as done and an inline comment as not done.
c-rhodes added inline comments.
================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:46-47
+ And<[IsVectorOfRankPred<[2]>,
+ CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>,
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})">]>,
+ description>;
----------------
awarzynski wrote:
> [nit] Perhaps extract the preconditions to dedicated definition?
>
> Also, when would this be triggered:
> ```
> CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})"
> ```
> It feels like a "VectorType verifier" that could be safely skipped (i.e. nothing SME specific).
Could you please clarify? I'm not sure what you mean. This predicate verifies the shape, i.e. that `vector<[16]x[16]xi8>` has shape (16, 16).
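To make the discussion concrete, here is a minimal standalone sketch of the three conditions the `And<[...]>` constraint combines: rank 2, all dimensions scalable, and a shape equal to the expected `dims`. `VectorTypeModel` and `isTileOfShape` are hypothetical names standing in for `::mlir::VectorType` and the generated predicate; this is not the MLIR API, only an illustration of the logic.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for ::mlir::VectorType (not the MLIR API).
struct VectorTypeModel {
  std::vector<int64_t> shape;     // e.g. {16, 16} for vector<[16]x[16]xi8>
  std::vector<bool> scalableDims; // one flag per dimension, true for "[N]"
};

// Returns true if `type` is a rank-2 vector whose dimensions are all
// scalable and whose shape equals `dims`.
inline bool isTileOfShape(const VectorTypeModel &type,
                          const std::array<int64_t, 2> &dims) {
  // Invariant of the model: one scalability flag per dimension.
  assert(type.shape.size() == type.scalableDims.size());

  // Condition 1: rank must be 2.
  if (type.shape.size() != 2)
    return false;

  // Condition 2: every dimension must be scalable.
  for (bool scalable : type.scalableDims)
    if (!scalable)
      return false;

  // Condition 3: the shape must match the expected dims exactly.
  return type.shape[0] == dims[0] && type.shape[1] == dims[1];
}
```

In this model, the third condition is what distinguishes `vector<[16]x[16]xi8>` from, say, `vector<[8]x[8]xi8>`, which is also rank 2 and fully scalable.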
================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:75-76
+ A `cast_tile_to_vector` operation does a cast from a tile id to a 2-d
+ scalable vector type, which represents an SME "virtual tile". This is used
+ in conjunction with `cast_vector_to_tile` to preserve dataflow and type
+ legality when lowering vector ops that have both inputs and outputs, to SME
----------------
awarzynski wrote:
> Perhaps:
> * "This is used in conjunction with `cast_vector_to_tile`" --> "This would normally be used in conjunction with "virtual tile load" operations to model the output of such Ops. This is required to preserve data-flow as SME intrinsics do not return values."
>
> Basically, this Op and `CastVectorToTile` complement each other, right? And I guess that's what we want to say here? But IMHO, this description should focus on `CastTileToVector`.
Thanks for the suggestion, this has cleaned it up nicely.
================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:178
+ ```mlir
+ // Allocate an 8-bit element ZA tile
+ %za0_b = arm_sme.get_tile_id : i8
----------------
awarzynski wrote:
> [nit] Is there any "tile allocation" really taking place? Perhaps "Allocate and return a "virtual tile" ID"?
There isn't from the perspective of the op I suppose, it's the pass that does that. Updated the comment.
================
Comment at: mlir/test/Dialect/ArmSME/canonicalize.mlir:10-11
+ // CHECK-NOT: arm_sme.cast_vector_to_tile
+ %tile = arm_sme.cast_tile_to_vector %tile_id_0 : i8 to vector<[16]x[16]xi8>
+ %tile_id_1 = arm_sme.cast_vector_to_tile %tile : vector<[16]x[16]xi8> to i8
+ // CHECK-NEXT: return %[[TILE_ID]] : i8
----------------
awarzynski wrote:
> What about "the other way round"?
> ```
> %tile_id = arm_sme.cast_vector_to_tile %tile_1 : vector<[16]x[16]xi8> to i8
> %tile_2 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
>
> ```
Good spot!
================
Comment at: mlir/test/Dialect/ArmSME/roundtrip.mlir:8
+ %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
+ return %0 : vector<[16]x[16]xi8>
+}
----------------
awarzynski wrote:
> Could you add one more element type? For example, `vector<[1]x[1]xi128>` (i.e. the other extreme).
I've added tests for all element types
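For reference, the two extremes mentioned above (`vector<[16]x[16]xi8>` and `vector<[1]x[1]xi128>`) follow from dividing the minimum streaming vector length of 128 bits by the element width. A minimal sketch of that arithmetic, assuming only the integer widths 8 to 128 bits; `kMinSVLBits`, `isSupportedElementWidth`, and `minTileDim` are hypothetical names for illustration, not part of the patch:

```cpp
#include <cassert>

// Minimum streaming vector length (SVL) in bits.
constexpr unsigned kMinSVLBits = 128;

// Hypothetical helper: element widths assumed for this sketch.
inline bool isSupportedElementWidth(unsigned elementBits) {
  return elementBits == 8 || elementBits == 16 || elementBits == 32 ||
         elementBits == 64 || elementBits == 128;
}

// Hypothetical helper: number of elements along each dimension of a
// virtual tile at the minimum SVL, i.e. the N in vector<[N]x[N]xiW>.
inline unsigned minTileDim(unsigned elementBits) {
  assert(isSupportedElementWidth(elementBits) && "unsupported element width");
  return kMinSVLBits / elementBits;
}
```

Under these assumptions the widths in between give `[8]x[8]` for 16-bit, `[4]x[4]` for 32-bit, and `[2]x[2]` for 64-bit elements.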
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D154941/new/
https://reviews.llvm.org/D154941