================
@@ -2082,23 +2082,27 @@ def Vector_GatherOp :
3-D and the result is 2-D:
```mlir
- func.func @gather_3D_to_2D(
- %base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
- %indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
- %fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
- %result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
- [%indices], %mask, %fall_thru : [...]
- return %result : vector<2x3xf32>
+ %base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
+ %indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
+ %fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
+ %result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
+ [%indices], %mask, %fall_thru : [...]
}
```
The indexing semantics are then,
```
- result[i,j] := if mask[i,j] then base[i0, i1, i2 + indices[i,j]]
+ result[i,j] := if mask[i,j] then base[%ofs_0, %ofs_1, %ofs_2 + indices[i,j]]
else pass_thru[i,j]
```
- The index into `base` only varies in the innermost ((k-1)-th) dimension.
+ Note, `indices` are element offsets - they are expressed in units of
+ elements (not bytes). Each element in `indices` represents a displacement
+ in units of elements from the starting element, i.e. `%base[%ofs_0, %ofs_1,
+ %ofs_2]` for the example above. Importantly, for MemRefs, `indices` are
+ interpreted assuming an identity (contiguous) MemRef
+ layout and do not account for non-identity strides.
----------------
banach-space wrote:
By “do not account”, I mean that the indices are interpreted purely in terms of logical element offsets, independent of the MemRef’s physical layout.
In particular:
* this is not an error, and does not require verifier changes, and
* the semantics are defined in terms of the logical view of the MemRef, as if it had an identity (contiguous) layout.
Any non-identity layout (e.g. non-unit strides) is handled during lowering, where the logical indices are mapped to physical addresses.
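To illustrate the distinction, here is a rough Python model of the semantics in the doc formula (`base[%ofs_0, %ofs_1, %ofs_2 + indices[i,j]]`). It is only a sketch, not the actual lowering: `physical_address`, `make_buffer` and `gather` are hypothetical helpers, the shapes and strides are made up, and it assumes every displaced index stays within the innermost dimension. The point is that the same `indices` give the same result for an identity layout and for a padded, strided one, because the strides are only applied when a logical index is mapped to an address.

```python
from itertools import product

SHAPE = (2, 3, 8)  # hypothetical logical shape of the memref


def physical_address(offset, strides, idx):
    """Map a logical index to a position in the flat buffer (the 'lowering' step)."""
    return offset + sum(s * i for s, i in zip(strides, idx))


def make_buffer(offset, strides):
    """Fill a flat buffer so that logical element [a, b, c] holds 100*a + 10*b + c."""
    last = tuple(d - 1 for d in SHAPE)
    buf = [None] * (physical_address(offset, strides, last) + 1)
    for a, b, c in product(*(range(d) for d in SHAPE)):
        buf[physical_address(offset, strides, (a, b, c))] = 100 * a + 10 * b + c
    return buf


def gather(buf, offset, strides, ofs, indices, mask, pass_thru):
    """result[i][j] = base[ofs0, ofs1, ofs2 + indices[i][j]] if mask[i][j] else pass_thru[i][j]."""
    rows, cols = len(indices), len(indices[0])
    result = [[None] * cols for _ in range(rows)]
    for i, j in product(range(rows), range(cols)):
        if mask[i][j]:
            # `indices` are element offsets applied to the logical index;
            # the layout (offset, strides) is not consulted until here.
            logical = (ofs[0], ofs[1], ofs[2] + indices[i][j])
            result[i][j] = buf[physical_address(offset, strides, logical)]
        else:
            result[i][j] = pass_thru[i][j]
    return result


ofs = (1, 2, 1)
indices = [[0, 2, 5], [1, 3, 4]]
mask = [[True, False, True], [True, True, False]]
pass_thru = [[-1, -1, -1], [-1, -1, -1]]

# Identity (contiguous) layout: strides follow directly from SHAPE.
dense_layout = (0, (24, 8, 1))
# Non-identity layout: non-zero offset and padded, non-unit strides.
padded_layout = (3, (48, 16, 2))

result_dense = gather(make_buffer(*dense_layout), *dense_layout,
                      ofs, indices, mask, pass_thru)
result_padded = gather(make_buffer(*padded_layout), *padded_layout,
                       ofs, indices, mask, pass_thru)

print(result_dense)
print(result_padded)
```

Both calls return the same values, since the layout only changes where each logical element lives in the buffer, not which logical element an index refers to.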
I’ll update the wording to make this clearer.
https://github.com/llvm/llvm-project/pull/181357