[Mlir-commits] [mlir] [mlir][vector] Rename gather/scatter arguments (nfc) (PR #153640)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Aug 25 08:26:04 PDT 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/153640

>From 798b079b7b36fe368949b98e34c1811c230ff58d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 14 Aug 2025 18:19:22 +0000
Subject: [PATCH] [mlir][vector] Rename gather/scatter arguments

Renames the `indices` operand of `vector.gather` and `vector.scatter` to `offsets`, and the `index_vec` operand to `indices`.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 30 +++++++++----------
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  8 ++---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 12 ++++----
 .../BufferizableOpInterfaceImpl.cpp           |  2 +-
 .../Vector/Transforms/LowerVectorGather.cpp   | 14 ++++-----
 .../Vector/Transforms/LowerVectorMask.cpp     |  2 +-
 .../Vector/Transforms/VectorUnroll.cpp        |  4 +--
 7 files changed, 36 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index bcc423a634148..130ed9083848b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2047,8 +2047,8 @@ def Vector_GatherOp :
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
   ]>,
     Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
-               Variadic<Index>:$indices,
-               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+               Variadic<Index>:$offsets,
+               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
                AnyVectorOfNonZeroRank:$pass_thru,
                ConfinedAttr<OptionalAttr<I64Attr>,
@@ -2072,11 +2072,11 @@ def Vector_GatherOp :
 
     ```mlir
     func.func @gather_3D_to_2D(
-        %base: memref<?x10x?xf32>, %i0: index, %i1: index, %i2: index,
-        %index_vec: vector<2x3xi32>, %mask: vector<2x3xi1>,
+        %base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
+        %indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
         %fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
-            %result = vector.gather %base[%i0, %i1, %i2]
-                                   [%index_vec], %mask, %fall_thru : [...]
+            %result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
+                                   [%indices], %mask, %fall_thru : [...]
             return %result : vector<2x3xf32>
     }
     ```
@@ -2084,7 +2084,7 @@ def Vector_GatherOp :
     The indexing semantics are then,
 
     ```
-    result[i,j] := if mask[i,j] then base[i0, i1, i2 + index_vec[i,j]]
+    result[i,j] := if mask[i,j] then base[ofs_0, ofs_1, ofs_2 + indices[i,j]]
                    else pass_thru[i,j]
     ```
     The index into `base` only varies in the innermost ((k-1)-th) dimension.
@@ -2118,16 +2118,16 @@ def Vector_GatherOp :
 
   let extraClassDeclaration = [{
     ShapedType getBaseType() { return getBase().getType(); }
-    VectorType getIndexVectorType() { return getIndexVec().getType(); }
+    VectorType getIndexVectorType() { return getIndices().getType(); }
     VectorType getMaskVectorType() { return getMask().getType(); }
     VectorType getPassThruVectorType() { return getPassThru().getType(); }
     VectorType getVectorType() { return getResult().getType(); }
   }];
 
   let assemblyFormat =
-    "$base `[` $indices `]` `[` $index_vec `]` `,` "
+    "$base `[` $offsets `]` `[` $indices `]` `,` "
     "$mask `,` $pass_thru attr-dict `:` type($base) `,` "
-    "type($index_vec)  `,` type($mask) `,` type($pass_thru) "
+    "type($indices)  `,` type($mask) `,` type($pass_thru) "
     "`into` type($result)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
@@ -2150,8 +2150,8 @@ def Vector_GatherOp :
 def Vector_ScatterOp :
   Vector_Op<"scatter">,
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
-               Variadic<Index>:$indices,
-               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+               Variadic<Index>:$offsets,
+               VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
                AnyVectorOfNonZeroRank:$valueToStore,
                ConfinedAttr<OptionalAttr<I64Attr>,
@@ -2207,15 +2207,15 @@ def Vector_ScatterOp :
 
   let extraClassDeclaration = [{
     MemRefType getMemRefType() { return getBase().getType(); }
-    VectorType getIndexVectorType() { return getIndexVec().getType(); }
+    VectorType getIndexVectorType() { return getIndices().getType(); }
     VectorType getMaskVectorType() { return getMask().getType(); }
     VectorType getVectorType() { return getValueToStore().getType(); }
   }];
 
   let assemblyFormat =
-      "$base `[` $indices `]` `[` $index_vec `]` `,` "
+      "$base `[` $offsets `]` `[` $indices `]` `,` "
       "$mask `,` $valueToStore attr-dict `:` type($base) `,` "
-      "type($index_vec)  `,` type($mask) `,` type($valueToStore)";
+      "type($indices)  `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index afc3d1b12ac0d..1ff7d5dad378e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -306,11 +306,11 @@ class VectorGatherOpConversion
 
     // Resolve address.
     Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
-                                     adaptor.getBase(), adaptor.getIndices());
+                                     adaptor.getBase(), adaptor.getOffsets());
     Value base = adaptor.getBase();
     Value ptrs =
         getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
-                       base, ptr, adaptor.getIndexVec(), vType);
+                       base, ptr, adaptor.getIndices(), vType);
 
     // Replace with the gather intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
@@ -362,10 +362,10 @@ class VectorScatterOpConversion
 
     // Resolve address.
     Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
-                                     adaptor.getBase(), adaptor.getIndices());
+                                     adaptor.getBase(), adaptor.getOffsets());
     Value ptrs =
         getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
-                       adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
+                       adaptor.getBase(), ptr, adaptor.getIndices(), vType);
 
     // Replace with the scatter intrinsic.
     rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2b2581d353673..bc93339a68ed3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5782,7 +5782,7 @@ LogicalResult GatherOp::verify() {
 
   if (resVType.getElementType() != baseType.getElementType())
     return emitOpError("base and result element type should match");
-  if (llvm::size(getIndices()) != baseType.getRank())
+  if (llvm::size(getOffsets()) != baseType.getRank())
     return emitOpError("requires ") << baseType.getRank() << " indices";
   if (resVType.getShape() != indVType.getShape())
     return emitOpError("expected result dim to match indices dim");
@@ -5854,11 +5854,11 @@ class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
     if (!isa<MemRefType>(op.getBase().getType()))
       return rewriter.notifyMatchFailure(op, "base must be of memref type");
 
-    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
       return failure();
 
     rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
-                                              op.getIndices(), op.getMask(),
+                                              op.getOffsets(), op.getMask(),
                                               op.getPassThru());
     return success();
   }
@@ -5882,7 +5882,7 @@ LogicalResult ScatterOp::verify() {
 
   if (valueVType.getElementType() != memType.getElementType())
     return emitOpError("base and valueToStore element type should match");
-  if (llvm::size(getIndices()) != memType.getRank())
+  if (llvm::size(getOffsets()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getShape() != indVType.getShape())
     return emitOpError("expected valueToStore dim to match indices dim");
@@ -5917,11 +5917,11 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ScatterOp op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+    if (failed(isZeroBasedContiguousSeq(op.getIndices())))
       return failure();
 
     rewriter.replaceOpWithNewOp<MaskedStoreOp>(
-        op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
+        op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 66196194b0585..546099ca975b7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -162,7 +162,7 @@ struct GatherOpInterface
       return failure();
     replaceOpWithNewBufferizedOp<vector::GatherOp>(
         rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
-        gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(),
+        gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(),
         gatherOp.getPassThru());
     return success();
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 90f21c53246b0..983018934a85c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -54,7 +54,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
 
   LogicalResult matchAndRewrite(vector::GatherOp op,
                                 PatternRewriter &rewriter) const override {
-    Value indexVec = op.getIndexVec();
+    Value indexVec = op.getIndices();
     Value maskVec = op.getMask();
     Value passThruVec = op.getPassThru();
 
@@ -69,7 +69,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
       Value passThruSubVec =
           vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
       return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
-                                      op.getIndices(), indexSubVec, maskSubVec,
+                                      op.getOffsets(), indexSubVec, maskSubVec,
                                       passThruSubVec);
     };
 
@@ -141,18 +141,18 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
     // 2. Generate new gather indices that will model the
     // strided access.
     IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
-    VectorType vType = op.getIndexVec().getType();
+    VectorType vType = op.getIndices().getType();
     Value mulCst = arith::ConstantOp::create(
         rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
 
     Value newIdxs =
-        arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst);
+        arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
 
     // 3. Create an updated gather op with the collapsed input memref and the
     // updated indices.
     Value newGather = vector::GatherOp::create(
         rewriter, op.getLoc(), op.getResult().getType(), collapsed,
-        op.getIndices(), newIdxs, op.getMask(), op.getPassThru());
+        op.getOffsets(), newIdxs, op.getMask(), op.getPassThru());
     rewriter.replaceOp(op, newGather);
 
     return success();
@@ -195,8 +195,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
     Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
         loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
-        op.getIndexVec());
-    auto baseOffsets = llvm::to_vector(op.getIndices());
+        op.getIndices());
+    auto baseOffsets = llvm::to_vector(op.getOffsets());
     Value lastBaseOffset = baseOffsets.back();
 
     Value result = op.getPassThru();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 45ef7f01a85f1..5617b067d249e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -269,7 +269,7 @@ struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
     // Replace the `vector.mask` operation.
     rewriter.replaceOpWithNewOp<GatherOp>(
         maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
-        gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
+        gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(),
         passthru);
     return success();
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 501abecfacd04..e8ecb0c0be846 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -640,7 +640,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
       // decomposed shape from each of the index, mask, and pass-through
       // vectors.
       Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
-          loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
+          loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
       Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
           loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
       Value passThruSubVec =
@@ -648,7 +648,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
               loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
               strides);
       auto slicedGather = vector::GatherOp::create(
-          rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
+          rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
           indexSubVec, maskSubVec, passThruSubVec);
 
       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(



More information about the Mlir-commits mailing list