[PATCH] D154941: [mlir][ArmSME] Add custom get_tile_id and cast ops
Andrzej Warzynski via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 17 02:36:10 PDT 2023
awarzynski added inline comments.
================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:46-47
+ And<[IsVectorOfRankPred<[2]>,
+ CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>,
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})">]>,
+ description>;
----------------
c-rhodes wrote:
> awarzynski wrote:
> > c-rhodes wrote:
> > > c-rhodes wrote:
> > > > awarzynski wrote:
> > > > > c-rhodes wrote:
> > > > > > awarzynski wrote:
> > > > > > > [nit] Perhaps extract the preconditions to a dedicated definition?
> > > > > > >
> > > > > > > Also, when would this be triggered:
> > > > > > > ```
> > > > > > > CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})"
> > > > > > > ```
> > > > > > > It feels like a "VectorType verifier" that could be safely skipped (i.e. nothing SME specific).
> > > > > >
> > > > > > Could you please clarify? I'm not sure what you mean. This verifies the shape, i.e. `vector<[16]x[16]xi8>` is (16, 16).
> > > > > I am just thinking that every vector that you create like this will satisfy this condition and to me this check feels redundant. But I am probably just failing to understand the underlying rationale. No harm in keeping this.
> > > > > I am just thinking that every vector that you create like this will satisfy this condition and to me this check feels redundant.
> > > >
> > > > And this verifies that :)
> > >
> > > To clarify, without this check:
> > > ```
> > > %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[8]xi32>
> > > ```
> > >
> > > would be valid
> >
> > Could you double-check? This works fine:
> >
> > ```
> > module {
> > func.func @arm_sme_cast_tile_to_vector_i8(%arg0: i8) -> vector<[4]x[16]xi8> {
> > %0 = arm_sme.cast_tile_to_vector %arg0 : i8 to vector<[4]x[16]xi8>
> > return %0 : vector<[4]x[16]xi8>
> > }
> > }
> > ```
> >
> > You will need to replace `SMETile` with `AnyVectorOfAnyRank` in the definition of `CastTileToVector`, but `IsVectorOfShape` should trigger in both cases, right?
>
> > Could you double-check? This works fine:
> >
> > ```
> > module {
> > func.func @arm_sme_cast_tile_to_vector_i8(%arg0: i8) -> vector<[4]x[16]xi8> {
> > %0 = arm_sme.cast_tile_to_vector %arg0 : i8 to vector<[4]x[16]xi8>
> > return %0 : vector<[4]x[16]xi8>
> > }
> > }
> > ```
>
> this fails for me (as expected):
>
> ```
> build/bin/mlir-opt foo.mlir
> foo.mlir:4:8: error: 'arm_sme.cast_tile_to_vector' op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[4]x[16]xi8>'
> %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
> ^
> foo.mlir:4:8: note: see current operation: %0 = "arm_sme.cast_tile_to_vector"(%arg0) : (i8) -> vector<[4]x[16]xi8>
> ```
>
> and doesn't if I remove the `IsVectorOfShape<dims>` check.
>
> >
> > You will need to replace `SMETile` with `AnyVectorOfAnyRank` in the definition of `CastTileToVector`, but `IsVectorOfShape` should trigger in both cases, right?
>
> I'm not sure I follow, please could you clarify?
You are right and I am wrong, sorry.
I've just checked the generated C++ code and it's this:
```
((::llvm::cast<::mlir::VectorType>(type).getShape() == ArrayRef<int64_t>({16, 16})))
// other similar checks
```
So the RHS is taken from:
```
def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
```
I thought that for this example (`vector<[4]x[8]xi32>`) it would check the following instead:
```
((::llvm::cast<::mlir::VectorType>(type).getShape() == ArrayRef<int64_t>({4, 8})))
```
i.e. take the RHS from the input (`vector<[4]x[8]xi32>`). Hence the confusion.
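For illustration, the behaviour can be modelled outside of MLIR. This is only a rough sketch and not the generated verifier: `Elem`, `VecType`, `kSMETiles` and `isSMETile` are hypothetical names, and the (element type, shape) pairs are copied from the diagnostic quoted above. The point is that the expected shape is a constant from the type definition and the type being checked only supplies the LHS:
```cpp
#include <cstdint>

// Hypothetical stand-in for a 2-D scalable vector type; this is NOT the MLIR API.
enum class Elem { I8, I16, I32, I64, I128, F16, BF16, F32, F64 };

struct VecType {
  Elem elem;
  std::int64_t rows;
  std::int64_t cols;
};

// Fixed (element type, shape) pairs, as listed in the diagnostic quoted above.
// These play the role of the constant RHS in the generated shape comparison.
constexpr VecType kSMETiles[] = {
    {Elem::I8, 16, 16}, {Elem::I16, 8, 8},  {Elem::I32, 4, 4},
    {Elem::I64, 2, 2},  {Elem::I128, 1, 1}, {Elem::F16, 8, 8},
    {Elem::BF16, 8, 8}, {Elem::F32, 4, 4},  {Elem::F64, 2, 2}};

// True only if the type matches one of the fixed tiles. The shape of `t` is
// compared against the constants above, never against itself.
constexpr bool isSMETile(VecType t) {
  for (const VecType &tile : kSMETiles)
    if (t.elem == tile.elem && t.rows == tile.rows && t.cols == tile.cols)
      return true;
  return false;
}

// vector<[16]x[16]xi8> is accepted.
static_assert(isSMETile({Elem::I8, 16, 16}), "expected a valid i8 tile");
// vector<[4]x[16]xi8> is rejected, matching the mlir-opt error above.
static_assert(!isSMETile({Elem::I8, 4, 16}), "expected rejection");
// vector<[4]x[8]xi32> is rejected too: {4, 8} != {4, 4}.
static_assert(!isSMETile({Elem::I32, 4, 8}), "expected rejection");
```
The checks are `static_assert`s, so compiling the snippet is enough to exercise them.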
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D154941/new/
https://reviews.llvm.org/D154941