[PATCH] D154941: [mlir][ArmSME] Add custom get_tile_id and cast ops

Andrzej Warzynski via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 17 02:36:10 PDT 2023


awarzynski added inline comments.


================
Comment at: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td:46-47
+      And<[IsVectorOfRankPred<[2]>,
+           CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>,
+           CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})">]>,
+  description>;
----------------
c-rhodes wrote:
> awarzynski wrote:
> > c-rhodes wrote:
> > > c-rhodes wrote:
> > > > awarzynski wrote:
> > > > > c-rhodes wrote:
> > > > > > awarzynski wrote:
> > > > > > > [nit] Perhaps extract the preconditions to a dedicated definition?
> > > > > > > 
> > > > > > > Also, when would this be triggered:
> > > > > > > ```
> > > > > > >  CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})"
> > > > > > > ```
> > > > > > > It feels like a "VectorType verifier" that could be safely skipped (i.e. nothing SME specific).
> > > > > > 
> > > > > > Please could you clarify? I'm not sure what you mean. This verifies the shape, i.e. that `vector<[16]x[16]xi8>` is (16, 16).
> > > > > I am just thinking that every vector that you create like this will satisfy this condition and to me this check feels redundant. But I am probably just failing to understand the underlying rationale. No harm in keeping this.
> > > > 
> > > > And this verifies that :)
> > > > 
> > > 
> > > To clarify, without this check:
> > > ```
> > > %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[8]xi32>
> > > ```
> > > 
> > > would be valid
> > 
> > Could you double-check? This works fine:
> > 
> > ```
> > module {
> >   func.func @arm_sme_cast_tile_to_vector_i8(%arg0: i8) -> vector<[4]x[16]xi8> {
> >     %0 = arm_sme.cast_tile_to_vector %arg0 : i8 to vector<[4]x[16]xi8>
> >     return %0 : vector<[4]x[16]xi8>
> >   }
> > }
> > ```
> > 
> > You will need to replace `SMETile` with `AnyVectorOfAnyRank` in the definition of `CastTileToVector`, but `IsVectorOfShape` should trigger in both cases, right?
> 
> this fails for me (as expected):
> 
> ```
> build/bin/mlir-opt foo.mlir
> foo.mlir:4:8: error: 'arm_sme.cast_tile_to_vector' op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[4]x[16]xi8>'
>   %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
>        ^
> foo.mlir:4:8: note: see current operation: %0 = "arm_sme.cast_tile_to_vector"(%arg0) : (i8) -> vector<[4]x[16]xi8>
> ```
> 
> and doesn't if I remove the `IsVectorOfShape<dims>` check.
> 
> > 
> > You will need to replace `SMETile` with `AnyVectorOfAnyRank` in the definition of `CastTileToVector`, but `IsVectorOfShape` should trigger in both cases, right?
> 
> I'm not sure I follow, please could you clarify?
You are right and I am wrong, sorry.

I've just checked the generated C++ code and it's this:
```
((::llvm::cast<::mlir::VectorType>(type).getShape() == ArrayRef<int64_t>({16, 16})))
// other similar checks
```
So the RHS is taken from:
```
  def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
                           nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
```

I thought that for this example (`vector<[4]x[8]xi32>`) it would check the following instead:
```
((::llvm::cast<::mlir::VectorType>(type).getShape() == ArrayRef<int64_t>({4, 8})))
```
i.e. take the RHS from the input (`vector<[4]x[8]xi32>`). Hence the confusion.


https://reviews.llvm.org/D154941


