[Mlir-commits] [mlir] [mlir][vector] Make gather/scatter index dimensions separately (PR #194395)
llvmlistbot at llvm.org
Mon Apr 27 07:58:07 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Krzysztof Drewniak (krzysz00)
<details>
<summary>Changes</summary>
This commit updates the semantics of vector.gather and vector.scatter into something that makes multi-dimensional memrefs/tensors easier to work with and resolves the inconsistencies highlighted by #187215 by allowing both possible semantics.
Previously, vector.gather / vector.scatter had one "index_vec" argument, which specified the offsets to gather at once you'd reached the start location given by base[offsets]. It was agreed that these indices would be interpreted in terms of the memref's layout, in that you'd apply the innermost memref stride, but it wasn't agreed whether indices that went past the end of the dimension were meant to index into the underlying memory or into the next dimension of the subview/strided memref.
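To make the ambiguity concrete, here is a minimal sketch in plain Python (not MLIR) of the two readings of the old single-`index_vec` form. The function names and the 2x2 subview of a 4-wide buffer are hypothetical, chosen only for illustration; the second function assumes "next dimension" means delinearizing within the logical shape.

```python
def addr_underlying_memory(offsets, idx, strides):
    """Reading 1: scale idx by the innermost stride and keep walking memory."""
    start = sum(o * s for o, s in zip(offsets, strides))
    return start + idx * strides[-1]


def addr_next_dimension(offsets, idx, shape, strides):
    """Reading 2: overflow in the innermost dim carries into the next dim."""
    # Linearize the start position within the logical (subview) shape.
    linear = 0
    for o, d in zip(offsets, shape):
        linear = linear * d + o
    linear += idx
    # Delinearize back into per-dimension coordinates.
    coords = []
    for d in reversed(shape):
        coords.append(linear % d)
        linear //= d
    coords.reverse()
    # Apply the memref strides to the wrapped coordinates.
    return sum(c * s for c, s in zip(coords, strides))


# A 2x2 subview with strides [4, 1] (i.e. cut out of a buffer 4 elements wide).
print(addr_underlying_memory((0, 0), 2, (4, 1)))       # element 2 of the parent row
print(addr_next_dimension((0, 0), 2, (2, 2), (4, 1)))  # element (1, 0) of the subview
```

With an index of 2 from the subview origin, the first reading lands on linear address 2 (outside the subview), while the second lands on address 4 (the start of the subview's second row).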
To resolve this ambiguity and get rid of a bunch of complex delinearization logic, vector.gather and vector.scatter now take 1 <= r <= [the rank of the base] index vectors which specify how we're gathering/scattering along the last r memref dimensions.
This means that you can manipulate the indexing into each memref dimension separately without needing to break up the indices with possibly dynamic values - for example, folding an expand_shape into a gather can now look like just merging together the relevant gather indices.
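As a rough reference model of the new semantics, the following Python sketch gathers from a nested list. It assumes 1-D index, mask and pass-through vectors for brevity (the op itself allows n-D vectors), and the function name `gather` and its argument layout are illustrative only, not the op's builder signature.

```python
def gather(base, offsets, index_vecs, mask, pass_thru):
    """Model of vector.gather with r index vectors over a k-D nested list."""
    k, r = len(offsets), len(index_vecs)
    assert 1 <= r <= k, "need between 1 and rank(base) index vectors"
    result = []
    for lane, (bit, fallback) in enumerate(zip(mask, pass_thru)):
        if not bit:
            # Masked-off lanes take the pass-through value.
            result.append(fallback)
            continue
        coords = list(offsets)
        # The d-th index vector perturbs dimension k - r + d.
        for d, vec in enumerate(index_vecs):
            coords[k - r + d] += vec[lane]
        elem = base
        for c in coords:
            elem = elem[c]
        result.append(elem)
    return result


base = [[10 * i + j for j in range(4)] for i in range(3)]
# r = 2: both dimensions of the 2-D base are indexed per lane.
print(gather(base, [1, 0], [[0, 1, -1], [2, 3, 0]], [True, False, True], [-1, -1, -1]))
# r = 1: only the innermost dimension varies, as in the old form.
print(gather(base, [2, 1], [[0, 2]], [True, True], [0, 0]))
```

Out-of-bounds behavior is not modeled here; in the op it is undefined for lanes whose mask bit is set.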
This commit also clarifies the required relationship between the index_vecs and the strides of the memref if the index vecs use a short integer type: you must be able to truncate the relevant memref strides to match the index vec type.
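One way to read the truncation requirement is "the stride survives a truncate-then-sign-extend round trip". A small Python sketch of that check follows; the helper names are hypothetical, and Python's unbounded integers stand in for fixed-width ones.

```python
def sign_extend(value, bits):
    """Keep the low `bits` bits of value and reinterpret them as signed."""
    value &= (1 << bits) - 1
    return value - (1 << bits) if value >> (bits - 1) else value


def stride_fits_index_type(stride, index_bits):
    """True if truncating stride to index_bits and sign-extending restores it."""
    return sign_extend(stride, index_bits) == stride


print(stride_fits_index_type(1024, 32))     # small stride, i32 indices
print(stride_fits_index_type(2 ** 31, 32))  # too large for a signed i32
print(stride_fits_index_type(200, 8))       # too large for a signed i8
```

A stride that fails this check would be a violation of the op's assumptions rather than something the lowering detects.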
This commit also updates the pattern that folds certain subviews into gathers to use the new semantics (and therefore support more strides), updates the lowerings to vector.loads along with the lowerings to LLVM to account for the new semantics, fixes some XeGPU patterns, and updates the one upstream creator of vector.gather to use the new form.
It also adds a canonicalization that moves splat indices from the index_vecs into the offsets.
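The idea behind that canonicalization can be sketched in Python as follows. This is only a model of the address equivalence: the real pattern works on SSA values, and whether it zeroes or drops the folded index vector is not something this sketch claims to reproduce. `fold_splat_indices` and `lane_coords` are hypothetical names.

```python
def fold_splat_indices(offsets, index_vecs):
    """Move any non-zero splat index vector into the matching scalar offset."""
    k, r = len(offsets), len(index_vecs)
    new_offsets = list(offsets)
    new_vecs = []
    for d, vec in enumerate(index_vecs):
        if len(set(vec)) == 1 and vec[0] != 0:
            # Every lane adds the same amount, so add it to the offset once.
            new_offsets[k - r + d] += vec[0]
            new_vecs.append([0] * len(vec))
        else:
            new_vecs.append(list(vec))
    return new_offsets, new_vecs


def lane_coords(offsets, index_vecs, lane):
    """Coordinates accessed by one lane: offsets plus per-dimension indices."""
    k, r = len(offsets), len(index_vecs)
    coords = list(offsets)
    for d, vec in enumerate(index_vecs):
        coords[k - r + d] += vec[lane]
    return coords


offsets, vecs = [1, 2, 3], [[5, 5, 5], [0, 1, 2]]
folded_offsets, folded_vecs = fold_splat_indices(offsets, vecs)
print(folded_offsets, folded_vecs)
# Each lane still touches the same coordinates after folding.
print(all(lane_coords(offsets, vecs, i) == lane_coords(folded_offsets, folded_vecs, i)
          for i in range(3)))
```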
---
Patch is 103.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/194395.diff
16 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+90-43)
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+53-8)
- (modified) mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp (+24-12)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+16-39)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+152-10)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+88-101)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+7-3)
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+2)
- (modified) mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir (+51-56)
- (modified) mlir/test/Dialect/Linalg/vectorization/extract.mlir (+20-35)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+39)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+72)
- (modified) mlir/test/Dialect/Vector/ops.mlir (+20)
- (modified) mlir/test/Dialect/Vector/vector-gather-lowering.mlir (+151-33)
- (modified) mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir (+3-8)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/gather.mlir (+30-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fdde3995f6333..794ae192bb5a4 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2057,24 +2057,29 @@ def Vector_MaskedStoreOp :
];
}
+def Vector_AtLeastOneIntegerVec : Variadic<VectorOfNonZeroRankOf<[AnyInteger, Index]>> {
+ let minSize = 1;
+}
+
def Vector_GatherOp :
Vector_Op<"gather", [
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
- DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
+ DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
+ AttrSizedOperandSegments
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
Variadic<Index>:$offsets,
- VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
+ Vector_AtLeastOneIntegerVec:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$pass_thru,
OptionalAttr<IntValidAlignment<I64Attr>>: $alignment)>,
Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary = [{
- Gathers elements from memory or ranked tensor into a vector as defined by an
- index vector and a mask vector.
+ Gathers elements from memory or ranked tensor into a vector as defined by
+ index vectors and a mask vector.
}];
let description = [{
@@ -2083,17 +2088,18 @@ def Vector_GatherOp :
on the values of an n-D mask vector.
If a mask bit is set, the corresponding result element is taken from `base`
- at an index defined by k indices and n-D `index_vec`. Otherwise, the element
- is taken from the pass-through vector. As an example, suppose that `base` is
- 3-D and the result is 2-D:
+ at an index defined by k indices and the values in r (for gather rank) n-D
+ `index_vec`s. Otherwise, the element is taken from the pass-through vector.
+ As an example, suppose that `base` is 3-D and the result is 2-D,
+ where we have r=2 index vectors:
```mlir
func.func @gather_3D_to_2D(
%base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
- %indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
- %fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
+ %indices0: vector<2x3xi32>, %indices1: vector<2x3xi32>,
+ %mask: vector<2x3xi1>, %fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
%result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
- [%indices], %mask, %fall_thru : [...]
+ [%indices0, %indices1], %mask, %fall_thru : [...]
return %result : vector<2x3xf32>
}
```
@@ -2101,10 +2107,10 @@ def Vector_GatherOp :
The indexing semantics are then,
```
- result[i,j] := if mask[i,j] then base[i0, i1, i2 + indices[i,j]]
+ result[i,j] := if mask[i,j] then base[i0, i1 + indices0[i, j], i2 + indices1[i, j]]
else pass_thru[i,j]
```
- The index into `base` only varies in the innermost ((k-1)-th) dimension.
+ The index into `base` only varies in the r innermost ((k-r)-th) dimensions.
If a mask bit is set and the corresponding index is out-of-bounds for the
given base, the behavior is undefined. If a mask bit is not set, the value
@@ -2115,6 +2121,16 @@ def Vector_GatherOp :
during progressively lowering to bring other memory operations closer to
hardware ISA support for a gather.
+ If the type of the memref offsets is potentially larger than the type of the
+ gather indices, it is assumed that the final `r` strides of the memref can be
+ truncated to the index type such that the linear gather index can be computed
+ without signed overflow (i.e. the truncated stride sign-extends back to the
+ original value). If an extension to the offset type is needed, the gather
+ indices are interpreted as signed integers.
+
+ Similarly, if the gather indices are wider than the memref offset type,
+ they must be losslessly truncatable to that type.
+
An optional `alignment` attribute allows to specify the byte alignment of the
gather operation. It must be a positive power of 2. The operation must access
memory at an address aligned to this boundary. Violating this requirement
@@ -2127,37 +2143,43 @@ def Vector_GatherOp :
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru
: memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
- // 2-D memref gathered to 1-D vector.
- %1 = vector.gather %base[%i, %j][%v], %mask, %pass_thru
- : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ // 2-D memref gathered to 1-D vector along two dimensions. The single
+ // index type is shared across all index vectors.
+ %1 = vector.gather %base[%i, %j][%v0, %v1], %mask, %pass_thru
+ : memref<16x16xf32>, vector<16xi32>, vector<16xi1>,
+ vector<16xf32> into vector<16xf32>
```
}];
let extraClassDeclaration = [{
ShapedType getBaseType() { return getBase().getType(); }
- VectorType getIndexVectorType() { return getIndices().getType(); }
+ VectorType getIndexVectorType() {
+ return cast<VectorType>(getIndices().front().getType());
+ }
VectorType getMaskVectorType() { return getMask().getType(); }
VectorType getPassThruVectorType() { return getPassThru().getType(); }
VectorType getVectorType() { return getResult().getType(); }
}];
- let assemblyFormat =
- "$base `[` $offsets `]` `[` $indices `]` `,` "
- "$mask `,` $pass_thru attr-dict `:` type($base) `,` "
- "type($indices) `,` type($mask) `,` type($pass_thru) "
- "`into` type($result)";
+ let assemblyFormat = [{
+ $base `[` $offsets `]` `[` $indices `]` `,`
+ $mask `,` $pass_thru attr-dict `:` type($base) `,`
+ custom<SameTypeVariadicOperands>(ref($indices), type($indices)) `,`
+ type($mask) `,` type($pass_thru) `into` type($result)
+ }];
let hasCanonicalizer = 1;
let hasVerifier = 1;
let builders = [
OpBuilder<(ins "VectorType":$resultType,
"Value":$base,
- "ValueRange":$indices,
- "Value":$index_vec,
+ "ValueRange":$offsets,
+ "ValueRange":$index_vecs,
"Value":$mask,
"Value":$passthrough,
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
- return build($_builder, $_state, resultType, base, indices, index_vec, mask, passthrough,
+ return build($_builder, $_state, resultType, base, offsets, index_vecs,
+ mask, passthrough,
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
nullptr);
}]>
@@ -2167,10 +2189,11 @@ def Vector_GatherOp :
def Vector_ScatterOp
: Vector_Op<"scatter",
[DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
- DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]>,
+ DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
+ AttrSizedOperandSegments]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemWrite]>:$base,
Variadic<Index>:$offsets,
- VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
+ Vector_AtLeastOneIntegerVec:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$valueToStore,
OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)>,
@@ -2183,12 +2206,20 @@ def Vector_ScatterOp
let description = [{
The scatter operation stores elements from a n-D vector into memory or ranked tensor as
- defined by a base with indices and an additional n-D index vector, but
- only if the corresponding bit in a n-D mask vector is set. Otherwise, no
- action is taken for that element. Informally the semantics are:
- ```
- if (mask[0]) base[index[0]] = value[0]
- if (mask[1]) base[index[1]] = value[1]
+ defined by a base with indices and an additional r (the scatter rank) n-D index
+ vectors, but only if the corresponding bit in a n-D mask vector is set.
+ Otherwise, no action is taken for that element. Informally the semantics are:
+ ```
+ if (mask[0, ..., 0]) {
+ base[offsets[0], ...,
+ offsets[k - r] + indices[0][0, 0, ... 0],
+ ...
+ offsets[k - 1] + indices[r - 1][0, 0, ... 0]] = value[0, 0 ..., 0]
+ }
+ if (mask[0, ..., 1]) {
+ base[offsets[0],
+ ...,
+ offsets[k - 1] + indices[r - 1][0, 0, ... 1]] = value[0, 0, ..., 1]
etc.
```
@@ -2208,6 +2239,16 @@ def Vector_ScatterOp
correspond to those of the `llvm.masked.scatter`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics).
+ If the type of the memref offsets is potentially larger than the type of the
+ scatter indices, it is assumed that the final `r` strides of the memref can be
+ truncated to the index type such that the linear scatter index can be computed
+ without signed overflow (i.e. the truncated stride sign-extends back to the
+ original value). If an extension to the offset type is needed, the scatter
+ indices are interpreted as signed integers.
+
+ Similarly, if the scatter indices are wider than the memref offset type,
+ they must be losslessly truncatable to that type.
+
An optional `alignment` attribute allows to specify the byte alignment of the
scatter operation. It must be a positive power of 2. The operation must access
memory at an address aligned to this boundary. Violating this requirement
@@ -2219,31 +2260,37 @@ def Vector_ScatterOp
vector.scatter %base[%c0][%v], %mask, %value
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
- vector.scatter %base[%i, %j][%v], %mask, %value
+ // The single index type is shared across all index vectors.
+ vector.scatter %base[%i, %j][%v0, %v1], %mask, %value
: memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
```
}];
let extraClassDeclaration = [{
ShapedType getBaseType() { return getBase().getType(); }
- VectorType getIndexVectorType() { return getIndices().getType(); }
+ VectorType getIndexVectorType() {
+ return cast<VectorType>(getIndices().front().getType());
+ }
VectorType getMaskVectorType() { return getMask().getType(); }
VectorType getVectorType() { return getValueToStore().getType(); }
}];
- let assemblyFormat = "$base `[` $offsets `]` `[` $indices `]` `,` "
- "$mask `,` $valueToStore attr-dict `:` type($base) `,` "
- "type($indices) `,` type($mask) `,` "
- "type($valueToStore) (`->` type($result)^)?";
+ let assemblyFormat = [{
+ $base `[` $offsets `]` `[` $indices `]` `,`
+ $mask `,` $valueToStore attr-dict `:` type($base) `,`
+ custom<SameTypeVariadicOperands>(ref($indices), type($indices)) `,`
+ type($mask) `,` type($valueToStore) (`->` type($result)^)?
+ }];
let hasCanonicalizer = 1;
let hasVerifier = 1;
let builders = [OpBuilder<
- (ins "Type":$resultType, "Value":$base, "ValueRange":$indices,
- "Value":$index_vec, "Value":$mask, "Value":$valueToStore,
+ (ins "Type":$resultType, "Value":$base, "ValueRange":$offsets,
+ "ValueRange":$index_vecs, "Value":$mask, "Value":$valueToStore,
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment),
[{
- return build($_builder, $_state, resultType, base, indices, index_vec, mask, valueToStore,
+ return build($_builder, $_state, resultType, base, offsets, index_vecs,
+ mask, valueToStore,
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
nullptr);
}]>];
@@ -2551,7 +2598,7 @@ def Vector_TypeCastOp :
}
def Vector_ConstantMaskOp :
- Vector_Op<"constant_mask", [Pure,
+ Vector_Op<"constant_mask", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>
]>,
Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
@@ -2611,7 +2658,7 @@ def Vector_ConstantMaskOp :
}
def Vector_CreateMaskOp :
- Vector_Op<"create_mask", [Pure,
+ Vector_Op<"create_mask", [Pure,
DeclareOpInterfaceMethods<VectorUnrollOpInterface>
]>,
Arguments<(ins Variadic<Index>:$mask_dim_sizes)>,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 43e0824fef6cd..87a2d79f08677 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -131,30 +131,75 @@ static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
return success();
}
-// Add an index vector component to a base pointer.
+// Add a sequence of index vectors componentwise to a base pointer, using the
+// strides from the given memref. The index vectors are linearized:
+// index = sum(strides.take_back(len(indices)) * indices)
static Value getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
const LLVMTypeConverter &typeConverter,
MemRefType memRefType, Value llvmMemref, Value base,
- Value index, VectorType vectorType) {
+ ValueRange indices, VectorType vectorType) {
assert(succeeded(isMemRefTypeSupported(memRefType, typeConverter)) &&
"unsupported memref type");
assert(vectorType.getRank() == 1 && "expected a 1-d vector type");
- auto pType = MemRefDescriptor(llvmMemref).getElementPtrType();
+
+ unsigned rank = memRefType.getRank();
+ MemRefDescriptor desc(llvmMemref);
+
+ int64_t numElems = vectorType.getDimSize(0);
+ bool isScalable = vectorType.getScalableDims()[0];
+ SmallVector<int32_t> zeroMask(numElems, 0);
+ Value i32Zero = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(0));
+
+ Value linearized;
+ auto idxElemType =
+ cast<VectorType>(indices.front().getType()).getElementType();
+ auto idxVecType = LLVM::getVectorType(idxElemType, numElems, isScalable);
+ unsigned r = indices.size();
+ for (auto [idx, stride] : llvm::zip_equal(
+ indices, llvm::map_range(llvm::seq(rank - r, rank), [&](unsigned d) {
+ return desc.stride(rewriter, loc, d);
+ }))) {
+
+ Value castStride = stride;
+ if (stride.getType() != idxElemType) {
+ unsigned strideBits = cast<IntegerType>(stride.getType()).getWidth();
+ unsigned idxBits = cast<IntegerType>(idxElemType).getWidth();
+ if (strideBits > idxBits) {
+ castStride = LLVM::TruncOp::create(rewriter, loc, idxElemType, stride,
+ LLVM::IntegerOverflowFlags::nsw);
+ } else {
+ castStride = LLVM::SExtOp::create(rewriter, loc, idxElemType, stride);
+ }
+ }
+ Value strideVec = [&]() {
+ Value poison = LLVM::PoisonOp::create(rewriter, loc, idxVecType);
+ Value inserted = LLVM::InsertElementOp::create(rewriter, loc, poison,
+ castStride, i32Zero);
+ return LLVM::ShuffleVectorOp::create(rewriter, loc, inserted, poison,
+ zeroMask);
+ }();
+ Value contribution = LLVM::MulOp::create(rewriter, loc, idx, strideVec);
+ linearized = linearized ? LLVM::AddOp::create(rewriter, loc, linearized,
+ contribution)
+ : contribution;
+ }
+
+ auto pType = desc.getElementPtrType();
auto ptrsType =
LLVM::getVectorType(pType, vectorType.getDimSize(0),
/*isScalable=*/vectorType.getScalableDims()[0]);
return LLVM::GEPOp::create(
rewriter, loc, ptrsType,
- typeConverter.convertType(memRefType.getElementType()), base, index);
+ typeConverter.convertType(memRefType.getElementType()), base, linearized);
}
-/// Convert `foldResult` into a Value. Integer attribute is converted to
-/// an LLVM constant op.
+/// Convert `foldResult` into a Value, using `llvm.mlir.constant` if needed.
static Value getAsLLVMValue(OpBuilder &builder, Location loc,
OpFoldResult foldResult) {
if (auto attr = dyn_cast<Attribute>(foldResult)) {
- auto intAttr = cast<IntegerAttr>(attr);
- return LLVM::ConstantOp::create(builder, loc, intAttr).getResult();
+ return LLVM::ConstantOp::create(builder, loc, cast<TypedAttr>(attr))
+ .getResult();
}
return cast<Value>(foldResult);
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 0aae5cb5bb6ad..afcd4caa7a323 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -379,22 +379,34 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
baseOffset =
arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
}
- Value indices = gatScatOp.getIndices();
- VectorType vecType = cast<VectorType>(indices.getType());
- Value strideVector =
- vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
- .getResult();
- Value stridedIndices =
- arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
+ // The op's index vectors may use a narrower element type (e.g. i32), but
+ // strides and the base vector are always index — cast everything to the
+ // index-typed vector before combining.
+ OperandRange indexVecs = gatScatOp.getIndices();
+ VectorType vecType = gatScatOp.getIndexVectorType();
+ auto indexVecType =
+ VectorType::get(vecType.getShape(), rewriter.getIndexType());
+ Value combinedIndices = arith::ConstantOp::create(
+ rewriter, loc,
+ DenseElementsAttr::get(indexVecType, rewriter.getIndexAttr(0)));
+
+ auto tailStrides = ArrayRef<Value>(strides).take_back(indexVecs.size());
+ for (auto [idx, stride] : llvm::zip_equal(indexVecs, tailStrides)) {
+ Value castIdx = idx;
+ if (vecType != indexVecType)
+ castIdx = arith::IndexCastOp::create(rewriter, loc, indexVecType, idx);
+ Value strideVec =
+ vector::BroadcastOp::create(rewriter, loc, indexVecType, stride);
+ Value stridedIdx = arith::MulIOp::create(rewriter, loc, strideVec, castIdx);
+ combinedIndices =
+ arith::AddIOp::create(rewriter, loc, combinedIndices, stridedIdx);
+ }
Value baseVector =
- vector::BroadcastOp::create(
- rewriter, loc,
- VectorType::get(vecType.getShape(), rewriter.getIndexType()),
- baseOffset)
+ vector::BroadcastOp::create(rewriter, loc, indexVecType, baseOffset)
.getResult();
- return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
+ return arith::AddIOp::create(rewriter, loc, baseVector, combinedIndices)
.getResult();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b57e66a1c3580..a44ae49647f06 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -894,45 +894,21 @@ tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
return success();
}
-/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
-/// generated from `tensor.extract`. The offset is calculated as follows
-/// (example using scalar values):
-///
-/// offset = extractOp.indices[0]
-/// for (i = 1; i < numIndices; i++)
-/// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
-///
-/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
-/// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
-static Value calculateGatherOffset(RewriterBase &rewriter,
- ...
[truncated]
``````````
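The linearization described in the comment on `getIndexedPtrs` (`index = sum(strides.take_back(len(indices)) * indices)`) can be modeled per lane in a few lines of Python. This is a sketch of the arithmetic only, assuming 1-D index vectors and ignoring the integer-width casts the lowering performs; `linearize` is a hypothetical name.

```python
def linearize(index_vecs, strides):
    """Per-lane element offset: dot product of indices with the trailing strides."""
    r = len(index_vecs)
    tail = strides[len(strides) - r:]  # the last r strides of the memref
    lanes = len(index_vecs[0])
    return [sum(vec[lane] * s for vec, s in zip(index_vecs, tail))
            for lane in range(lanes)]


# Strides of a contiguous memref<?x10x16xf32> would be [160, 16, 1].
print(linearize([[0, 1], [2, 3]], [160, 16, 1]))
```

The resulting per-lane offsets are what get added to the base pointer (already advanced by the scalar offsets) to form the vector of addresses.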
</details>
https://github.com/llvm/llvm-project/pull/194395