[Mlir-commits] [mlir] [MLIR][XeGPU][VectorToXeGPU] Add lowering from vector.gather/scatter to xegpu.load/store (PR #158024)

Dmitry Chigarev llvmlistbot at llvm.org
Thu Sep 11 06:29:34 PDT 2025


https://github.com/dchigarev updated https://github.com/llvm/llvm-project/pull/158024

>From 7b3e740cc54fb8d2a83214450dd204f60d09e692 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Wed, 10 Sep 2025 10:35:05 +0000
Subject: [PATCH 1/2] [MLIR][XeGPU] Add lowering from vector.gather/scatter to
 xegpu.load/store

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 142 +++++++++++++++-
 .../VectorToXeGPU/gather-to-xegpu.mlir        | 160 ++++++++++++++++++
 .../VectorToXeGPU/scatter-to-xegpu.mlir       | 125 ++++++++++++++
 3 files changed, 419 insertions(+), 8 deletions(-)
 create mode 100644 mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
 create mode 100644 mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 819c2e5973ffd..52def22ad025e 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,6 +97,21 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
+// Accepts only memrefs (rank >= 1) whose innermost stride is statically 1.
+static LogicalResult gatherScatterPreconditions(PatternRewriter &rewriter,
+                                                Operation *op, Type baseType) {
+  auto srcTy = dyn_cast<MemRefType>(baseType);
+  if (!srcTy)
+    return rewriter.notifyMatchFailure(op, "Expects memref source");
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.empty() ||
+      strides.back() != 1)
+    return rewriter.notifyMatchFailure(
+        op, "Buffer must be contiguous in the innermost dimension");
+  return success();
+}
+
 static xegpu::CreateNdDescOp
 createNdDescriptor(PatternRewriter &rewriter, Location loc,
                    xegpu::TensorDescType descType, TypedValue<MemRefType> src,
@@ -183,11 +198,15 @@ static void adjustStridesForPermutation(AffineMap permMap,
 // Computes memory strides for vector transfer operations, handling both
 // static and dynamic memrefs while applying permutation transformations
 // for XeGPU lowering.
-static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
+        vector::GatherOp, vector::ScatterOp>::value>>
+static SmallVector<Value> computeStrides(OpType xferOp,
                                          PatternRewriter &rewriter) {
   SmallVector<Value> strides;
   Value baseMemref = xferOp.getBase();
-  AffineMap permMap = xferOp.getPermutationMap();
   MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
 
   Location loc = xferOp.getLoc();
@@ -222,8 +241,14 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
         rewriter, loc, resultTypes, baseMemref);
     strides.append(meta.getStrides().begin(), meta.getStrides().end());
   }
-  // Adjust strides according to the permutation map (e.g., for transpose)
-  adjustStridesForPermutation(permMap, strides);
+
+  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
+                                vector::TransferWriteOp>::value) {
+    AffineMap permMap = xferOp.getPermutationMap();
+    // Adjust strides according to the permutation map (e.g., for transpose)
+    adjustStridesForPermutation(permMap, strides);
+  }
+
   return strides;
 }
 
@@ -334,9 +359,45 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
   return localOffsets;
 }
 
+// Compute the element-wise offsets for vector.gather or vector.scatter ops.
+//
+// The base offsets are linearized against `strides` and the result is
+// broadcast-added to the per-element indices. The indices step along the
+// innermost dimension, whose stride gatherScatterPreconditions requires to be
+// 1, so they are added without scaling. vector.gather/scatter accept any
+// integer element type for the indices, so non-`index` ones are index_cast.
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
+static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
+                            ArrayRef<Value> strides) {
+  Location loc = gatScatOp.getLoc();
+  Value linearOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  for (auto [offset, stride] : llvm::zip(gatScatOp.getOffsets(), strides)) {
+    Value offsetContrib = arith::MulIOp::create(rewriter, loc, offset, stride);
+    linearOffset =
+        arith::AddIOp::create(rewriter, loc, linearOffset, offsetContrib);
+  }
+  Value indices = gatScatOp.getIndices();
+  auto indexVecTy = VectorType::get(
+      cast<VectorType>(indices.getType()).getShape(), rewriter.getIndexType());
+  // arith.addi below requires both operands to have the same type.
+  if (indices.getType() != indexVecTy)
+    indices = arith::IndexCastOp::create(rewriter, loc, indexVecTy, indices);
+  Value baseVector =
+      vector::BroadcastOp::create(rewriter, loc, indexVecTy, linearOffset)
+          .getResult();
+  return arith::AddIOp::create(rewriter, loc, baseVector, indices).getResult();
+}
+
 // Collapse memref shape to 1D
-static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
-                                PatternRewriter &rewriter) {
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
+        vector::GatherOp, vector::ScatterOp>::value>>
+static Value collapseMemrefTo1D(OpType xferOp, PatternRewriter &rewriter) {
   Location loc = xferOp.getLoc();
 
   Value baseMemref = xferOp.getBase();
@@ -546,6 +607,69 @@ struct TransferWriteLowering
   }
 };
 
+struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern<vector::GatherOp>::OpRewritePattern;
+  // Lowers vector.gather to xegpu.load followed by an arith.select on the mask.
+  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(gatherScatterPreconditions(rewriter, gatherOp,
+                                          gatherOp.getBase().getType())))
+      return failure();
+
+    Location loc = gatherOp.getLoc();
+    VectorType vectorType = gatherOp.getVectorType();
+
+    SmallVector<Value> strides = computeStrides(gatherOp, rewriter);
+    if (strides.empty())
+      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");
+    // Linear offsets index into a 1-D (collapsed) view of the base memref.
+    Value localOffsets = computeOffsets(rewriter, gatherOp, strides);
+    Value flatMemref = collapseMemrefTo1D(gatherOp, rewriter);
+
+    auto xeGatherOp = xegpu::LoadGatherOp::create(
+        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
+        /*chunk_size=*/IntegerAttr{},
+        /*l1_hint=*/xegpu::CachePolicyAttr{},
+        /*l2_hint=*/xegpu::CachePolicyAttr{},
+        /*l3_hint=*/xegpu::CachePolicyAttr{});
+    // Lanes whose mask bit is false take the pass-through value instead.
+    auto selectOp =
+        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
+                                xeGatherOp.getResult(), gatherOp.getPassThru());
+    rewriter.replaceOp(gatherOp, selectOp.getResult());
+    return success();
+  }
+};
+
+struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
+  using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;
+  // Lowers vector.scatter to xegpu.store and erases the original op.
+  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(gatherScatterPreconditions(rewriter, scatterOp,
+                                          scatterOp.getBase().getType())))
+      return failure();
+
+    Location loc = scatterOp.getLoc();
+    SmallVector<Value> strides = computeStrides(scatterOp, rewriter);
+    if (strides.empty())
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "Failed to compute strides");
+    // Linear offsets index into a 1-D (collapsed) view of the base memref.
+    Value localOffsets = computeOffsets(rewriter, scatterOp, strides);
+    Value flatMemref = collapseMemrefTo1D(scatterOp, rewriter);
+
+    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
+                                  flatMemref, localOffsets, scatterOp.getMask(),
+                                  /*chunk_size=*/IntegerAttr{},
+                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
+                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
+                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
+    rewriter.eraseOp(scatterOp);
+    return success();
+  }
+};
+
 struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
 
@@ -661,6 +785,8 @@ struct ConvertVectorToXeGPUPass
 
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
-               StoreLowering, ContractionLowering>(patterns.getContext());
+  // Transfer, load/store, gather/scatter and contraction lowerings.
+  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
+               ScatterLowering, GatherLowering, StoreLowering,
+               ContractionLowering>(patterns.getContext());
 }
diff --git a/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
new file mode 100644
index 0000000000000..6b249ba37becd
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
@@ -0,0 +1,160 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+gpu.module @xevm_module {
+gpu.func @load_1D_vector(%source: memref<8x16x32xf32>,
+     %off1: index, %off2: index, %off3: index,
+     %indices: vector<8xindex>, %mask: vector<8xi1>,
+     %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
+       %pass_thru : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL:  @load_1D_vector(
+// CHECK-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>
+// CHECK-SAME:   %[[MASK:.+]]: vector<8xi1>
+// CHECK-SAME:   %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// CHECK:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
+// CHECK:        gpu.return %[[RES]] : vector<8xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_memref(%source: memref<8x32xf32>,
+     %off1: index, %off2: index,
+     %indices: vector<8xindex>, %mask: vector<8xi1>,
+     %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off1, %off2][%indices], %mask,
+       %pass_thru : memref<8x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL:  @load_2D_memref(
+// CHECK-SAME:   %[[SRC:.+]]: memref<8x32xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>
+// CHECK-SAME:   %[[MASK:.+]]: vector<8xi1>
+// CHECK-SAME:   %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK:        arith.muli {{.*}} : index
+// CHECK:        arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<8x32xf32> into memref<256xf32>
+// CHECK:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<256xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
+// CHECK:        gpu.return %[[RES]] : vector<8xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
+    %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask,
+       %pass_thru : memref<8x16x32xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL:  @load_2D_vector(
+// CHECK-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8x16xindex>
+// CHECK-SAME:   %[[MASK:.+]]: vector<8x16xi1>
+// CHECK-SAME:   %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8x16xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// CHECK:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
+// CHECK:        gpu.return %[[RES]] : vector<8x16xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
+    %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask,
+       %pass_thru : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL:  @load_dynamic_source(
+// CHECK-SAME:   %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8x16xindex>
+// CHECK-SAME:   %[[MASK:.+]]: vector<8x16xi1>
+// CHECK-SAME:   %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
+// CHECK:        memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8x16xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
+// CHECK:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
+// CHECK:        gpu.return %[[RES]] : vector<8x16xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
+    %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask,
+       %pass_thru : memref<?x8x16xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL:  @load_dynamic_source2(
+// CHECK-SAME:   %[[SRC:.+]]: memref<?x8x16xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8x16xindex>
+// CHECK-SAME:   %[[MASK:.+]]: vector<8x16xi1>
+// CHECK-SAME:   %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
+// CHECK-NOT:    memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8x16xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
+// CHECK:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
+// CHECK:        gpu.return %[[RES]] : vector<8x16xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @no_load_tensor(%source: tensor<32x64xf32>,  // Negative test: only memref bases are lowered; tensor gathers stay as-is.
+    %off: index, %indices: vector<8x16xindex>,
+    %mask: vector<8x16xi1>, %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+  %0 = vector.gather %source[%off, %off][%indices], %mask,
+       %pass_thru : tensor<32x64xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL:  @no_load_tensor(
+// CHECK:        vector.gather
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @no_load_non_unit_inner_stride(  // Negative test: the innermost stride is dynamic (?), so it is not provably 1.
+    %source: memref<32xf32, strided<[?], offset: ?>>,
+    %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>,
+    %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off][%indices], %mask, %pass_thru
+    : memref<32xf32, strided<[?], offset: ?>>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL:  @no_load_non_unit_inner_stride(
+// CHECK:        vector.gather
+}
+
diff --git a/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
new file mode 100644
index 0000000000000..4448ed5b6cda9
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
@@ -0,0 +1,125 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+gpu.module @xevm_module {
+gpu.func @store_1D_vector(%vec: vector<8xf32>, %source: memref<8x16x32xf32>,
+     %off1: index, %off2: index, %off3: index,
+     %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off1, %off2, %off3][%indices], %mask, %vec
+       : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL:  @store_1D_vector(
+// CHECK-SAME:   %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// CHECK:        xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1>
+// CHECK:        gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_2D_memref(%vec: vector<8xf32>, %source: memref<8x32xf32>,
+     %off1: index, %off2: index,
+     %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off1, %off2][%indices], %mask, %vec
+       : memref<8x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL:  @store_2D_memref(
+// CHECK-SAME:   %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<8x32xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK:        arith.muli {{.*}} : index
+// CHECK:        arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<8x32xf32> into memref<256xf32>
+// CHECK:        xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, memref<256xf32>, vector<8xindex>, vector<8xi1>
+// CHECK:        gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_2D_vector(%vec: vector<8x16xf32>, %source: memref<8x16x32xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>) {
+  vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+       : memref<8x16x32xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+  gpu.return
+}
+// CHECK-LABEL:  @store_2D_vector(
+// CHECK-SAME:   %[[VAL:.+]]: vector<8x16xf32>, %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>) {
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8x16xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// CHECK:        xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1>
+// CHECK:        gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_dynamic_source(%vec: vector<8x16xf32>, %source: memref<?x?x?xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>) {
+  vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+       : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+  gpu.return
+}
+// CHECK-LABEL:  @store_dynamic_source(
+// CHECK-SAME:   %[[VAL:.+]]: vector<8x16xf32>, %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>) {
+// CHECK:        memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8x16xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
+// CHECK:        xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, memref<?xf32>, vector<8x16xindex>, vector<8x16xi1>
+// CHECK:        gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_dynamic_source2(%vec: vector<8x16xf32>, %source: memref<?x8x16xf32>,
+    %off0: index, %off1: index, %off2: index,
+    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>) {
+  vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+       : memref<?x8x16xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+  gpu.return
+}
+// CHECK-LABEL:  @store_dynamic_source2(
+// CHECK-SAME:   %[[VAL:.+]]: vector<8x16xf32>, %[[SRC:.+]]: memref<?x8x16xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>) {
+// CHECK-NOT:    memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8x16xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
+// CHECK:        xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, memref<?xf32>, vector<8x16xindex>, vector<8x16xi1>
+// CHECK:        gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @no_store_non_unit_inner_stride(  // Negative test: the innermost stride is dynamic (?), so it is not provably 1.
+    %vec: vector<8xf32>, %source: memref<32xf32, strided<[?], offset: ?>>,
+    %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off][%indices], %mask, %vec
+    : memref<32xf32, strided<[?], offset: ?>>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL:  @no_store_non_unit_inner_stride(
+// CHECK:        vector.scatter
+}

>From a7501133b4ffcb008fe6ab10a15ce2bc236e73f8 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 11 Sep 2025 13:29:17 +0000
Subject: [PATCH 2/2] Add alignment handling

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 12 +++++++++
 .../VectorToXeGPU/gather-to-xegpu.mlir        | 27 +++++++++++++++++++
 .../VectorToXeGPU/scatter-to-xegpu.mlir       | 23 ++++++++++++++++
 3 files changed, 62 insertions(+)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 52def22ad025e..dc70d3f2a4946 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -626,6 +626,12 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
     Value localOffsets = computeOffsets(rewriter, gatherOp, strides);
     Value flatMemref = collapseMemrefTo1D(gatherOp, rewriter);
 
+    if (auto alignment = gatherOp.getAlignment()) {
+      flatMemref = memref::AssumeAlignmentOp::create(rewriter, loc, flatMemref,
+                                                     alignment.value())
+                       .getResult();
+    }
+
     auto xeGatherOp = xegpu::LoadGatherOp::create(
         rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
         /*chunk_size=*/IntegerAttr{},
@@ -659,6 +665,12 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
     Value localOffsets = computeOffsets(rewriter, scatterOp, strides);
     Value flatMemref = collapseMemrefTo1D(scatterOp, rewriter);
 
+    if (auto alignment = scatterOp.getAlignment()) {
+      flatMemref = memref::AssumeAlignmentOp::create(rewriter, loc, flatMemref,
+                                                     alignment.value())
+                       .getResult();
+    }
+
     xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
                                   flatMemref, localOffsets, scatterOp.getMask(),
                                   /*chunk_size=*/IntegerAttr{},
diff --git a/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
index 6b249ba37becd..4a1988c45ad43 100644
--- a/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
@@ -158,3 +158,30 @@ gpu.func @no_load_non_unit_inner_stride(
 // CHECK:        vector.gather
 }
 
+// -----
+gpu.module @xevm_module {
+gpu.func @load_1D_aligned(%source: memref<8x16x32xf32>,
+     %off1: index, %off2: index, %off3: index,
+     %indices: vector<8xindex>, %mask: vector<8xi1>,
+     %pass_thru: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
+       %pass_thru {alignment = 256} : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL:  @load_1D_aligned(
+// CHECK-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>
+// CHECK-SAME:   %[[MASK:.+]]: vector<8xi1>
+// CHECK-SAME:   %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// CHECK:        %[[COLLAPSE_ALIGN:.+]] = memref.assume_alignment %[[COLLAPSE]], 256 : memref<4096xf32>
+// CHECK:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_ALIGN]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
+// CHECK:        gpu.return %[[RES]] : vector<8xf32>
+}
+
diff --git a/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
index 4448ed5b6cda9..73c4e735d01aa 100644
--- a/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
@@ -123,3 +123,26 @@ gpu.func @no_store_non_unit_inner_stride(
 // CHECK-LABEL:  @no_store_non_unit_inner_stride(
 // CHECK:        vector.scatter
 }
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_1D_aligned(%vec: vector<8xf32>, %source: memref<8x16x32xf32>,
+     %off1: index, %off2: index, %off3: index,
+     %indices: vector<8xindex>, %mask: vector<8xi1>) {
+  vector.scatter %source[%off1, %off2, %off3][%indices], %mask, %vec {alignment = 256}
+       : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+  gpu.return
+}
+// CHECK-LABEL:  @store_1D_aligned(
+// CHECK-SAME:   %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK-COUNT-2: arith.muli {{.*}} : index
+// CHECK-COUNT-2: arith.addi {{.*}} : index
+// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// CHECK:        %[[COLLAPSE_ALIGN:.+]] = memref.assume_alignment %[[COLLAPSE]], 256 : memref<4096xf32>
+// CHECK:        xegpu.store %[[VAL]], %[[COLLAPSE_ALIGN]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1>
+// CHECK:        gpu.return
+}



More information about the Mlir-commits mailing list