[Mlir-commits] [mlir] [mlir][vector] Rename gather/scatter arguments (nfc) (PR #153640)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Aug 25 08:26:04 PDT 2025
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/153640
>From 798b079b7b36fe368949b98e34c1811c230ff58d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 14 Aug 2025 18:19:22 +0000
Subject: [PATCH] [mlir][vector] Rename gather/scatter arguments
Renames the `indices` argument of vector.gather/vector.scatter to `offsets`, and `index_vec` to `indices`.
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 30 +++++++++----------
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 8 ++---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++----
.../BufferizableOpInterfaceImpl.cpp | 2 +-
.../Vector/Transforms/LowerVectorGather.cpp | 14 ++++-----
.../Vector/Transforms/LowerVectorMask.cpp | 2 +-
.../Vector/Transforms/VectorUnroll.cpp | 4 +--
7 files changed, 36 insertions(+), 36 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index bcc423a634148..130ed9083848b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2047,8 +2047,8 @@ def Vector_GatherOp :
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
- Variadic<Index>:$indices,
- VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+ Variadic<Index>:$offsets,
+ VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$pass_thru,
ConfinedAttr<OptionalAttr<I64Attr>,
@@ -2072,11 +2072,11 @@ def Vector_GatherOp :
```mlir
func.func @gather_3D_to_2D(
- %base: memref<?x10x?xf32>, %i0: index, %i1: index, %i2: index,
- %index_vec: vector<2x3xi32>, %mask: vector<2x3xi1>,
+ %base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
+ %indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
%fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
- %result = vector.gather %base[%i0, %i1, %i2]
- [%index_vec], %mask, %fall_thru : [...]
+ %result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
+ [%indices], %mask, %fall_thru : [...]
return %result : vector<2x3xf32>
}
```
@@ -2084,7 +2084,7 @@ def Vector_GatherOp :
The indexing semantics are then,
```
- result[i,j] := if mask[i,j] then base[i0, i1, i2 + index_vec[i,j]]
+ result[i,j] := if mask[i,j] then base[ofs_0, ofs_1, ofs_2 + indices[i,j]]
else pass_thru[i,j]
```
The index into `base` only varies in the innermost ((k-1)-th) dimension.
@@ -2118,16 +2118,16 @@ def Vector_GatherOp :
let extraClassDeclaration = [{
ShapedType getBaseType() { return getBase().getType(); }
- VectorType getIndexVectorType() { return getIndexVec().getType(); }
+ VectorType getIndexVectorType() { return getIndices().getType(); }
VectorType getMaskVectorType() { return getMask().getType(); }
VectorType getPassThruVectorType() { return getPassThru().getType(); }
VectorType getVectorType() { return getResult().getType(); }
}];
let assemblyFormat =
- "$base `[` $indices `]` `[` $index_vec `]` `,` "
+ "$base `[` $offsets `]` `[` $indices `]` `,` "
"$mask `,` $pass_thru attr-dict `:` type($base) `,` "
- "type($index_vec) `,` type($mask) `,` type($pass_thru) "
+ "type($indices) `,` type($mask) `,` type($pass_thru) "
"`into` type($result)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
@@ -2150,8 +2150,8 @@ def Vector_GatherOp :
def Vector_ScatterOp :
Vector_Op<"scatter">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
- Variadic<Index>:$indices,
- VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+ Variadic<Index>:$offsets,
+ VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$valueToStore,
ConfinedAttr<OptionalAttr<I64Attr>,
@@ -2207,15 +2207,15 @@ def Vector_ScatterOp :
let extraClassDeclaration = [{
MemRefType getMemRefType() { return getBase().getType(); }
- VectorType getIndexVectorType() { return getIndexVec().getType(); }
+ VectorType getIndexVectorType() { return getIndices().getType(); }
VectorType getMaskVectorType() { return getMask().getType(); }
VectorType getVectorType() { return getValueToStore().getType(); }
}];
let assemblyFormat =
- "$base `[` $indices `]` `[` $index_vec `]` `,` "
+ "$base `[` $offsets `]` `[` $indices `]` `,` "
"$mask `,` $valueToStore attr-dict `:` type($base) `,` "
- "type($index_vec) `,` type($mask) `,` type($valueToStore)";
+ "type($indices) `,` type($mask) `,` type($valueToStore)";
let hasCanonicalizer = 1;
let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index afc3d1b12ac0d..1ff7d5dad378e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -306,11 +306,11 @@ class VectorGatherOpConversion
// Resolve address.
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
- adaptor.getBase(), adaptor.getIndices());
+ adaptor.getBase(), adaptor.getOffsets());
Value base = adaptor.getBase();
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
- base, ptr, adaptor.getIndexVec(), vType);
+ base, ptr, adaptor.getIndices(), vType);
// Replace with the gather intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
@@ -362,10 +362,10 @@ class VectorScatterOpConversion
// Resolve address.
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
- adaptor.getBase(), adaptor.getIndices());
+ adaptor.getBase(), adaptor.getOffsets());
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
- adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
+ adaptor.getBase(), ptr, adaptor.getIndices(), vType);
// Replace with the scatter intrinsic.
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2b2581d353673..bc93339a68ed3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5782,7 +5782,7 @@ LogicalResult GatherOp::verify() {
if (resVType.getElementType() != baseType.getElementType())
return emitOpError("base and result element type should match");
- if (llvm::size(getIndices()) != baseType.getRank())
+ if (llvm::size(getOffsets()) != baseType.getRank())
return emitOpError("requires ") << baseType.getRank() << " indices";
if (resVType.getShape() != indVType.getShape())
return emitOpError("expected result dim to match indices dim");
@@ -5854,11 +5854,11 @@ class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
if (!isa<MemRefType>(op.getBase().getType()))
return rewriter.notifyMatchFailure(op, "base must be of memref type");
- if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+ if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();
rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
- op.getIndices(), op.getMask(),
+ op.getOffsets(), op.getMask(),
op.getPassThru());
return success();
}
@@ -5882,7 +5882,7 @@ LogicalResult ScatterOp::verify() {
if (valueVType.getElementType() != memType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(getIndices()) != memType.getRank())
+ if (llvm::size(getOffsets()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
@@ -5917,11 +5917,11 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ScatterOp op,
PatternRewriter &rewriter) const override {
- if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+ if (failed(isZeroBasedContiguousSeq(op.getIndices())))
return failure();
rewriter.replaceOpWithNewOp<MaskedStoreOp>(
- op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
+ op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 66196194b0585..546099ca975b7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -162,7 +162,7 @@ struct GatherOpInterface
return failure();
replaceOpWithNewBufferizedOp<vector::GatherOp>(
rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
- gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
+ gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(),
gatherOp.getPassThru());
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 90f21c53246b0..983018934a85c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -54,7 +54,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
- Value indexVec = op.getIndexVec();
+ Value indexVec = op.getIndices();
Value maskVec = op.getMask();
Value passThruVec = op.getPassThru();
@@ -69,7 +69,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
Value passThruSubVec =
vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
- op.getIndices(), indexSubVec, maskSubVec,
+ op.getOffsets(), indexSubVec, maskSubVec,
passThruSubVec);
};
@@ -141,18 +141,18 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
// 2. Generate new gather indices that will model the
// strided access.
IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
- VectorType vType = op.getIndexVec().getType();
+ VectorType vType = op.getIndices().getType();
Value mulCst = arith::ConstantOp::create(
rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
Value newIdxs =
- arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst);
+ arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
// 3. Create an updated gather op with the collapsed input memref and the
// updated indices.
Value newGather = vector::GatherOp::create(
rewriter, op.getLoc(), op.getResult().getType(), collapsed,
- op.getIndices(), newIdxs, op.getMask(), op.getPassThru());
+ op.getOffsets(), newIdxs, op.getMask(), op.getPassThru());
rewriter.replaceOp(op, newGather);
return success();
@@ -195,8 +195,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
- op.getIndexVec());
- auto baseOffsets = llvm::to_vector(op.getIndices());
+ op.getIndices());
+ auto baseOffsets = llvm::to_vector(op.getOffsets());
Value lastBaseOffset = baseOffsets.back();
Value result = op.getPassThru();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 45ef7f01a85f1..5617b067d249e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -269,7 +269,7 @@ struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
// Replace the `vector.mask` operation.
rewriter.replaceOpWithNewOp<GatherOp>(
maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
- gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
+ gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(),
passthru);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 501abecfacd04..e8ecb0c0be846 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -640,7 +640,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
// decomposed shape from each of the index, mask, and pass-through
// vectors.
Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
+ loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
Value passThruSubVec =
@@ -648,7 +648,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
strides);
auto slicedGather = vector::GatherOp::create(
- rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
+ rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
indexSubVec, maskSubVec, passThruSubVec);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
More information about the Mlir-commits
mailing list