[Mlir-commits] [mlir] c4617bc - [MLIR][XeGPU][VectorToXeGPU] Add lowering from vector.gather/scatter to xegpu.load/store (#158024)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 19 02:12:17 PDT 2025


Author: Dmitry Chigarev
Date: 2025-09-19T11:12:14+02:00
New Revision: c4617bcae1308cf256bbd3738065eba2a4be8eb2

URL: https://github.com/llvm/llvm-project/commit/c4617bcae1308cf256bbd3738065eba2a4be8eb2
DIFF: https://github.com/llvm/llvm-project/commit/c4617bcae1308cf256bbd3738065eba2a4be8eb2.diff

LOG: [MLIR][XeGPU][VectorToXeGPU] Add lowering from vector.gather/scatter to xegpu.load/store (#158024)

Adds a lowering of `vector.gather`/`vector.scatter` to `xegpu.load`/`xegpu.store`.

High-level steps to lower vector.gather/vector.scatter:
```
%0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
       %pass_thru : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
```

1. Compute the strides and the memref offset (`%base_offset`) of the `%source`
memref using the `computeMemrefMeta` function from the transfer_read/transfer_write lowering.
2. Compute the linear base offset `%lin_off = %base_offset + %off1 *
strides#0 + %off2 * strides#1 + %off3 * strides#2`.
3. Combine the linear offset with `%indices`, scaling the indices by the innermost
stride: `%off = (broadcast %lin_off : index to vector<8xindex>) + %indices * (broadcast strides#2 : index to vector<8xindex>)`.
4. Convert the memref to an i64 base pointer: `%flat_memref =
arith.index_cast (memref.extract_aligned_pointer_as_index %source) : index to i64`.
5. Perform load/store: `%vec = xegpu.load %flat_memref[%off], %mask`
6. For gather only, apply a select so that masked-off lanes take their value from the pass_thru vector: `%res
= arith.select %mask, %vec, %pass_thru`.

Added: 
    mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
    mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir

Modified: 
    mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 852c322cc6467..9f5585a701438 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -183,11 +183,15 @@ static void adjustStridesForPermutation(AffineMap permMap,
 // Computes memory strides and a memref offset for vector transfer operations,
 // handling both static and dynamic memrefs while applying permutation
 // transformations for XeGPU lowering.
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
+        vector::GatherOp, vector::ScatterOp>::value>>
 static std::pair<SmallVector<Value>, Value>
-computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
+computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
   SmallVector<Value> strides;
   Value baseMemref = xferOp.getBase();
-  AffineMap permMap = xferOp.getPermutationMap();
   MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
 
   Location loc = xferOp.getLoc();
@@ -197,9 +201,14 @@ computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
     SmallVector<int64_t> intStrides;
     if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
       return {{}, offsetVal};
-    // Wrap static strides as MLIR values
-    for (int64_t s : intStrides)
-      strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
+      return ShapedType::isDynamic(strideVal);
+    });
+
+    if (!hasDynamicStrides)
+      for (int64_t s : intStrides)
+        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+
     if (!ShapedType::isDynamic(offset))
       offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
   }
@@ -232,8 +241,14 @@ computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
     if (!offsetVal)
       offsetVal = meta.getOffset();
   }
-  // Adjust strides according to the permutation map (e.g., for transpose)
-  adjustStridesForPermutation(permMap, strides);
+
+  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
+                                vector::TransferWriteOp>::value) {
+    AffineMap permMap = xferOp.getPermutationMap();
+    // Adjust strides according to the permutation map (e.g., for transpose)
+    adjustStridesForPermutation(permMap, strides);
+  }
+
   return {strides, offsetVal};
 }
 
@@ -339,9 +354,51 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
   return localOffsets;
 }
 
+// Compute the element-wise offsets for vector.gather or vector.scatter ops.
+//
+// This function linearizes the base offsets of the gather/scatter operation
+// and combines them with the per-element indices to produce a final vector of
+// memory offsets.
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
+static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
+                            ArrayRef<Value> strides, Value baseOffset) {
+  Location loc = gatScatOp.getLoc();
+  SmallVector<Value> offsets = gatScatOp.getOffsets();
+  for (size_t i = 0; i < offsets.size(); ++i) {
+    Value offsetContrib =
+        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
+    baseOffset =
+        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
+  }
+  Value indices = gatScatOp.getIndices();
+  VectorType vecType = cast<VectorType>(indices.getType());
+
+  Value strideVector =
+      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
+          .getResult();
+  Value stridedIndices =
+      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
+
+  Value baseVector =
+      vector::BroadcastOp::create(
+          rewriter, loc,
+          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
+          baseOffset)
+          .getResult();
+  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
+      .getResult();
+}
+
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
+        vector::GatherOp, vector::ScatterOp>::value>>
 // Convert memref to i64 base pointer
-static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
-                              PatternRewriter &rewriter) {
+static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
   Location loc = xferOp.getLoc();
   auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                       rewriter, loc, xferOp.getBase())
@@ -539,6 +596,71 @@ struct TransferWriteLowering
   }
 };
 
+struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern<vector::GatherOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
+    if (!srcTy)
+      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");
+
+    Location loc = gatherOp.getLoc();
+    VectorType vectorType = gatherOp.getVectorType();
+
+    auto meta = computeMemrefMeta(gatherOp, rewriter);
+    if (meta.first.empty())
+      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");
+
+    Value localOffsets =
+        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
+    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
+
+    auto xeGatherOp = xegpu::LoadGatherOp::create(
+        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
+        /*chunk_size=*/IntegerAttr{},
+        /*l1_hint=*/xegpu::CachePolicyAttr{},
+        /*l2_hint=*/xegpu::CachePolicyAttr{},
+        /*l3_hint=*/xegpu::CachePolicyAttr{});
+
+    auto selectOp =
+        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
+                                xeGatherOp.getResult(), gatherOp.getPassThru());
+    rewriter.replaceOp(gatherOp, selectOp.getResult());
+    return success();
+  }
+};
+
+struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
+  using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
+    if (!srcTy)
+      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");
+
+    Location loc = scatterOp.getLoc();
+    auto meta = computeMemrefMeta(scatterOp, rewriter);
+    if (meta.first.empty())
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "Failed to compute strides");
+
+    Value localOffsets =
+        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
+    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
+
+    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
+                                  flatMemref, localOffsets, scatterOp.getMask(),
+                                  /*chunk_size=*/IntegerAttr{},
+                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
+                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
+                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
+    rewriter.eraseOp(scatterOp);
+    return success();
+  }
+};
+
 struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
 
@@ -654,6 +776,8 @@ struct ConvertVectorToXeGPUPass
 
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
-               StoreLowering, ContractionLowering>(patterns.getContext());
+  patterns
+      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
+           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
+          patterns.getContext());
 }

diff  --git a/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
new file mode 100644
index 0000000000000..2a319869a7b06
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
@@ -0,0 +1,251 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+gpu.module @xevm_module {
+gpu.func @load_1D_vector(%source: memref<8x16x32xf32>,
+     %off1: index, %off2: index, %off3: index,
+     %indices: vector<8xindex>, %mask: vector<8xi1>,
+     %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
+       %pass_thru : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL:  @load_1D_vector(
+// CHECK-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>
+// CHECK-SAME:   %[[MASK:.+]]: vector<8xi1>
+// CHECK-SAME:   %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
+// CHECK:        gpu.return %[[RES]] : vector<8xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_memref(%source: memref<8x32xf32>,
+     %off1: index, %off2: index,
+     %indices: vector<8xindex>, %mask: vector<8xi1>,
+     %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off1, %off2][%indices], %mask,
+       %pass_thru : memref<8x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL:  @load_2D_memref(
+// CHECK-SAME:   %[[SRC:.+]]: memref<8x32xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>
+// CHECK-SAME:   %[[MASK:.+]]: vector<8xi1>
+// CHECK-SAME:   %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK-COUNT1: arith.muli {{.*}} : index
+// CHECK-COUNT1: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x32xf32> -> index
+// CHECK:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
+// CHECK:        gpu.return %[[RES]] : vector<8xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
+    %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask,
+       %pass_thru : memref<8x16x32xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL:  @load_2D_vector(
+// CHECK-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8x16xindex>
+// CHECK-SAME:   %[[MASK:.+]]: vector<8x16xi1>
+// CHECK-SAME:   %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8x16xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
+// CHECK:        gpu.return %[[RES]] : vector<8x16xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
+    %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask,
+       %pass_thru : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL:  @load_dynamic_source(
+// CHECK-SAME:   %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8x16xindex>
+// CHECK-SAME:   %[[MASK:.+]]: vector<8x16xi1>
+// CHECK-SAME:   %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
+// CHECK:        memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8x16xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?xf32> -> index
+// CHECK:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
+// CHECK:        gpu.return %[[RES]] : vector<8x16xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
+    %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask,
+       %pass_thru : memref<?x8x16xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL:  @load_dynamic_source2(
+// CHECK-SAME:   %[[SRC:.+]]: memref<?x8x16xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8x16xindex>
+// CHECK-SAME:   %[[MASK:.+]]: vector<8x16xi1>
+// CHECK-SAME:   %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
+// CHECK-NOT:    memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8x16xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x8x16xf32> -> index
+// CHECK:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
+// CHECK:        gpu.return %[[RES]] : vector<8x16xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @no_load_tensor(%source: tensor<32x64xf32>,
+    %off: index, %indices: vector<8x16xindex>,
+    %mask: vector<8x16xi1>, %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+  %0 = vector.gather %source[%off, %off][%indices], %mask,
+       %pass_thru : tensor<32x64xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL:  @no_load_tensor(
+// CHECK:        vector.gather
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @gather_from_subview(%source: memref<4096x4096xf16>,
+                              %memref_off: index, %off1: index, %off2: index,
+                              %indices: vector<8xindex>,
+                              %mask: vector<8xi1>,
+                              %pass_thru: vector<8xf16>) -> vector<8xf16> {
+  %subview = memref.subview %source[%memref_off, %memref_off] [256, 256] [1, 1]
+      : memref<4096x4096xf16>
+        to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+  %0 = vector.gather %subview[%off1, %off2][%indices], %mask, %pass_thru
+       : memref<256x256xf16, strided<[4096, 1], offset: ?>>,
+         vector<8xindex>, vector<8xi1>, vector<8xf16>
+         into vector<8xf16>
+  gpu.return %0 : vector<8xf16>
+}
+// CHECK-LABEL:  @gather_from_subview(
+// CHECK-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
+// CHECK-SAME:   %[[MEMREF_OFF:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>,
+// CHECK-SAME:   %[[MASK:.+]]: vector<8xi1>,
+// CHECK-SAME:   %[[PASS:.+]]: vector<8xf16>) -> vector<8xf16> {
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[MEMREF_OFF]], %[[MEMREF_OFF]]] [256, 256] [1, 1]
+// CHECK:        %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// CHECK:        arith.muli {{.*}}%[[OFF1]]{{.*}} : index
+// CHECK:        arith.addi %[[OFFSET]]{{.*}} : index
+// CHECK:        %[[BASE_OFF:.+]] = arith.addi {{.*}}%[[OFF2]]{{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast %[[BASE_OFF]] : index to vector<8xindex>
+// CHECK:        %[[LIN:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK:        %[[BASE_IDX:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// CHECK:        %[[BASE_I64:.+]] = arith.index_cast %[[BASE_IDX]] : index to i64
+// CHECK:        %[[VEC:.+]] = xegpu.load %[[BASE_I64]]{{\[}}%[[LIN]]{{\]}}, %[[MASK]]
+// CHECK-SAME:     : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
+// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS]] : vector<8xi1>, vector<8xf16>
+// CHECK:        gpu.return %[[RES]] : vector<8xf16>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @non_unit_inner_stride_1D(
+    %source: memref<32xf32, strided<[?], offset: ?>>,
+    %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>,
+    %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off][%indices], %mask, %pass_thru
+       : memref<32xf32, strided<[?], offset: ?>>,
+         vector<8xindex>, vector<8xi1>, vector<8xf32>
+         into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL:  @non_unit_inner_stride_1D(
+// CHECK-SAME:   %[[SRC:.+]]: memref<32xf32, strided<[?], offset: ?>>,
+// CHECK-SAME:   %[[OFF1:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>,
+// CHECK-SAME:   %[[MASK:.+]]: vector<8xi1>, %[[PASS:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK:        %[[BB:.+]], %[[M_OFF:.+]], %[[SZ:.+]], %[[STRIDE:.+]] = memref.extract_strided_metadata %[[SRC]]
+// CHECK:        arith.muli %[[OFF1]], %[[STRIDE]] : index
+// CHECK:        arith.addi {{.*}} : index
+// CHECK:        %[[STRD_VEC:.+]] = vector.broadcast %[[STRIDE]] : index to vector<8xindex>
+// CHECK:        %[[STRD_INDICES:.+]] = arith.muli %[[STRD_VEC:.+]], %[[INDICES]] : vector<8xindex>
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[STRD_INDICES]] : vector<8xindex>
+// CHECK:        %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32xf32, strided<[?], offset: ?>> -> index
+// CHECK:        %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK:        %[[V:.+]] = xegpu.load %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[V]], %[[PASS]] : vector<8xi1>, vector<8xf32>
+// CHECK:        gpu.return %[[RES]] : vector<8xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @non_unit_inner_stride_3D(
+    %source: memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>,
+    %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask, %pass_thru
+       : memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+         vector<8xindex>, vector<8xi1>, vector<8xf32>
+         into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL:  @non_unit_inner_stride_3D(
+// CHECK-SAME:   %[[SRC:.+]]: memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+// CHECK-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>,
+// CHECK-SAME:   %[[PASS:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK:        %[[BB:.+]], %[[M_OFF:.+]], %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
+// CHECK:        arith.muli %[[OFF0]], %[[STRIDES]]#0 : index
+// CHECK:        arith.addi {{.*}} : index
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK:        %[[STRD_INDICES:.+]] = arith.muli {{.*}}%[[INDICES]]{{.*}} : vector<8xindex>
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[STRD_INDICES]] : vector<8xindex>
+// CHECK:        %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>> -> index
+// CHECK:        %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK:        %[[V:.+]] = xegpu.load %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[V]], %[[PASS]] : vector<8xi1>, vector<8xf32>
+// CHECK:        gpu.return %[[RES]] : vector<8xf32>
+}

diff  --git a/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
new file mode 100644
index 0000000000000..ffd3f170c0fad
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
@@ -0,0 +1,206 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+// Scatter of a 1-D vector into a static 3-D memref. The lowering is expected
+// to linearize the three base offsets with the memref strides, broadcast the
+// scalar result, add it to %indices, and store through the aligned base
+// pointer cast to i64.
+gpu.module @xevm_module {
+gpu.func @store_1D_vector(%vec: vector<8xf32>, %source: memref<8x16x32xf32>,
+     %off1: index, %off2: index, %off3: index,
+     %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off1, %off2, %off3][%indices], %mask, %vec
+       : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL:  @store_1D_vector(
+// CHECK-SAME:   %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK:        %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK:        %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK:        xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
+// CHECK:        gpu.return
+}
+
+// -----
+// Scatter into a static 2-D memref. Linearizing the two base offsets is
+// expected to need a single scalar multiply and a single scalar add
+// (presumably the unit inner stride folds away).
+gpu.module @xevm_module {
+gpu.func @store_2D_memref(%vec: vector<8xf32>, %source: memref<8x32xf32>,
+     %off1: index, %off2: index,
+     %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off1, %off2][%indices], %mask, %vec
+       : memref<8x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL:  @store_2D_memref(
+// CHECK-SAME:   %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<8x32xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK:        arith.muli {{.*}} : index
+// CHECK:        arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK:        %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x32xf32> -> index
+// CHECK:        %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK:        xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
+// CHECK:        gpu.return
+}
+
+// -----
+// Scatter of 2-D value/index/mask vectors into a static 3-D memref. The
+// linearized scalar offset is broadcast to the 8x16 index shape before being
+// added to %indices.
+gpu.module @xevm_module {
+gpu.func @store_2D_vector(%vec: vector<8x16xf32>, %source: memref<8x16x32xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>) {
+  vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+       : memref<8x16x32xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+  gpu.return
+}
+// CHECK-LABEL:  @store_2D_vector(
+// CHECK-SAME:   %[[VAL:.+]]: vector<8x16xf32>, %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>) {
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8x16xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK:        %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK:        %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK:        xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
+// CHECK:        gpu.return
+}
+
+// -----
+// Scatter into a fully dynamic memref. The strides are not known statically,
+// so the lowering is expected to read them with memref.extract_strided_metadata
+// before linearizing the base offsets.
+gpu.module @xevm_module {
+gpu.func @store_dynamic_source(%vec: vector<8x16xf32>, %source: memref<?x?x?xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>) {
+  vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+       : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+  gpu.return
+}
+// CHECK-LABEL:  @store_dynamic_source(
+// CHECK-SAME:   %[[VAL:.+]]: vector<8x16xf32>, %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>) {
+// CHECK:        memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8x16xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK:        %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?xf32> -> index
+// CHECK:        %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK:        xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
+// CHECK:        gpu.return
+}
+
+// -----
+// Scatter into a memref whose only dynamic dimension is the outermost one.
+// All strides stay static, so no memref.extract_strided_metadata is expected
+// before the offset arithmetic.
+gpu.module @xevm_module {
+gpu.func @store_dynamic_source2(%vec: vector<8x16xf32>, %source: memref<?x8x16xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>) {
+  vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+       : memref<?x8x16xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+  gpu.return
+}
+// CHECK-LABEL:  @store_dynamic_source2(
+// CHECK-SAME:   %[[VAL:.+]]: vector<8x16xf32>, %[[SRC:.+]]: memref<?x8x16xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>) {
+// CHECK-NOT:    memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8x16xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK:        %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x8x16xf32> -> index
+// CHECK:        %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK:        xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
+// CHECK:        gpu.return
+}
+
+// -----
+// Scatter into a 1-D memref with a dynamic stride. The stride is read via
+// memref.extract_strided_metadata, broadcast, and multiplied into %indices
+// before the linearized base offset is added.
+gpu.module @xevm_module {
+gpu.func @non_unit_inner_stride_1D(
+    %vec: vector<8xf32>, %source: memref<32xf32, strided<[?], offset: ?>>,
+    %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off][%indices], %mask, %vec
+    : memref<32xf32, strided<[?], offset: ?>>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL:  @non_unit_inner_stride_1D(
+// CHECK-SAME:   %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<32xf32, strided<[?], offset: ?>>,
+// CHECK-SAME:   %[[OFF1:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK:        %[[BB:.+]], %[[M_OFF:.+]], %[[SZ:.+]], %[[STRIDE:.+]] = memref.extract_strided_metadata %[[SRC]]
+// CHECK:        arith.muli %[[OFF1]], %[[STRIDE]] : index
+// CHECK:        arith.addi {{.*}} : index
+// CHECK:        %[[STRD_VEC:.+]] = vector.broadcast %[[STRIDE]] : index to vector<8xindex>
+// CHECK:        %[[STRD_INDICES:.+]] = arith.muli %[[STRD_VEC]], %[[INDICES]] : vector<8xindex>
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[STRD_INDICES]] : vector<8xindex>
+// CHECK:        %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32xf32, strided<[?], offset: ?>> -> index
+// CHECK:        %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK:        xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
+// CHECK:        gpu.return
+}
+
+// -----
+// Scatter into a 3-D memref with a non-unit innermost stride (2). Each base
+// offset is expected to contribute one scalar multiply/add pair, and %indices
+// is scaled by the inner stride before the base offset is added.
+gpu.module @xevm_module {
+gpu.func @non_unit_inner_stride_3D(
+    %vec: vector<8xf32>,
+    %source: memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+    : memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+      vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL:  @non_unit_inner_stride_3D(
+// CHECK-SAME:   %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
+// CHECK-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK:        %[[BB:.+]], %[[M_OFF:.+]], %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
+// CHECK:        arith.muli %[[OFF0]], %[[STRIDES]]#0 : index
+// CHECK:        arith.addi {{.*}} : index
+// CHECK:        arith.muli %[[OFF1]], {{.*}} : index
+// CHECK:        arith.addi {{.*}} : index
+// CHECK:        arith.muli %[[OFF2]], {{.*}} : index
+// CHECK:        arith.addi {{.*}} : index
+// CHECK:        %[[STRD_INDICES:.+]] = arith.muli {{.*}}%[[INDICES]]{{.*}} : vector<8xindex>
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[STRD_INDICES]] : vector<8xindex>
+// CHECK:        %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>> -> index
+// CHECK:        %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK:        xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
+// CHECK:        gpu.return
+}
+
+// -----
+// Scatter into a memref.subview with a dynamic offset. The expected IR reads
+// the subview's offset with memref.extract_strided_metadata, folds it into the
+// scalar linear offset, and takes the aligned base pointer of the subview
+// itself (not of the original %source). %off2 is expected to appear directly
+// in the final scalar add, since the subview's inner stride is 1.
+gpu.module @xevm_module {
+gpu.func @scatter_into_subview(%vals: vector<8xf16>,
+                               %source: memref<4096x4096xf16>,
+                               %memref_off: index, %off1: index, %off2: index,
+                               %indices: vector<8xindex>,
+                               %mask: vector<8xi1>) {
+  %subview = memref.subview %source[%memref_off, %memref_off] [256, 256] [1, 1]
+      : memref<4096x4096xf16>
+        to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+  vector.scatter %subview[%off1, %off2][%indices], %mask, %vals
+      : memref<256x256xf16, strided<[4096, 1], offset: ?>>,
+        vector<8xindex>, vector<8xi1>, vector<8xf16>
+  gpu.return
+}
+// CHECK-LABEL:  @scatter_into_subview(
+// CHECK-SAME:   %[[VALS:.+]]: vector<8xf16>,
+// CHECK-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
+// CHECK-SAME:   %[[MEMREF_OFF:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[MEMREF_OFF]], %[[MEMREF_OFF]]] [256, 256] [1, 1]
+// CHECK:        %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// CHECK:        arith.muli {{.*}}%[[OFF1]]{{.*}} : index
+// CHECK:        arith.addi %[[OFFSET]]{{.*}} : index
+// CHECK:        %[[BASE_OFF:.+]] = arith.addi {{.*}}%[[OFF2]]{{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast %[[BASE_OFF]] : index to vector<8xindex>
+// CHECK:        %[[LIN:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK:        %[[BASE_IDX:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// CHECK:        %[[BASE_I64:.+]] = arith.index_cast %[[BASE_IDX]] : index to i64
+// CHECK:        xegpu.store %[[VALS]], %[[BASE_I64]]{{\[}}%[[LIN]]{{\]}}, %[[MASK]] : vector<8xf16>, i64, vector<8xindex>, vector<8xi1>
+// CHECK:        gpu.return
+}


        


More information about the Mlir-commits mailing list