[PATCH] D154941: [mlir][ArmSME] Add custom get_tile_id and cast ops

Cullen Rhodes via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 17 01:30:09 PDT 2023


c-rhodes added inline comments.


================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:46-47
+      And<[IsVectorOfRankPred<[2]>,
+           CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>,
+           CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})">]>,
+  description>;
----------------
awarzynski wrote:
> c-rhodes wrote:
> > c-rhodes wrote:
> > > awarzynski wrote:
> > > > c-rhodes wrote:
> > > > > awarzynski wrote:
> > > > > > [nit] Perhaps extract the preconditions into a dedicated definition?
> > > > > > 
> > > > > > Also, when would this be triggered:
> > > > > > ```
> > > > > >  CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})"
> > > > > > ```
> > > > > > It feels like a "VectorType verifier" that could be safely skipped (i.e. nothing SME-specific).
> > > > > 
> > > > > Could you please clarify? I'm not sure what you mean. This verifies the shape, i.e. that `vector<[16]x[16]xi8>` has shape (16, 16).
> > > > I am just thinking that every vector that you create like this will satisfy this condition and to me this check feels redundant. But I am probably just failing to understand the underlying rationale. No harm in keeping this.
> > > > I am just thinking that every vector that you create like this will satisfy this condition and to me this check feels redundant.
> > > 
> > > And this verifies that :)
> > 
> > To clarify, without this check the following would be valid:
> > ```
> > %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[8]xi32>
> > ```
> 
> Could you double-check? This works fine:
> 
> ```
> module {
>   func.func @arm_sme_cast_tile_to_vector_i8(%arg0: i8) -> vector<[4]x[16]xi8> {
>     %0 = arm_sme.cast_tile_to_vector %arg0 : i8 to vector<[4]x[16]xi8>
>     return %0 : vector<[4]x[16]xi8>
>   }
> }
> ```
> 
> You will need to replace `SMETile` with `AnyVectorOfAnyRank` in the definition of `CastTileToVector`, but `IsVectorOfShape` should trigger in both cases, right?
> > To clarify, without this check:
> > ```
> > %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[8]xi32>
> > ```
> > would be valid
> 
> Could you double-check? This works fine:
> 
> ```
> module {
>   func.func @arm_sme_cast_tile_to_vector_i8(%arg0: i8) -> vector<[4]x[16]xi8> {
>     %0 = arm_sme.cast_tile_to_vector %arg0 : i8 to vector<[4]x[16]xi8>
>     return %0 : vector<[4]x[16]xi8>
>   }
> }
> ```

This fails for me (as expected):

```
build/bin/mlir-opt foo.mlir
foo.mlir:4:8: error: 'arm_sme.cast_tile_to_vector' op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[4]x[16]xi8>'
  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
       ^
foo.mlir:4:8: note: see current operation: %0 = "arm_sme.cast_tile_to_vector"(%arg0) : (i8) -> vector<[4]x[16]xi8>
```

and doesn't fail if I remove the `IsVectorOfShape<dims>` check.

> 
> You will need to replace `SMETile` with `AnyVectorOfAnyRank` in the definition of `CastTileToVector`, but `IsVectorOfShape` should trigger in both cases, right?

I'm not sure I follow. Could you please clarify?



https://reviews.llvm.org/D154941


