[Mlir-commits] [mlir] [MLIR][XeGPU][VectorToXeGPU] Add lowering from vector.gather/scatter to xegpu.load/store (PR #158024)
Dmitry Chigarev
llvmlistbot at llvm.org
Fri Sep 12 04:02:18 PDT 2025
https://github.com/dchigarev updated https://github.com/llvm/llvm-project/pull/158024
From d97a3c215131a62f408388d0efa50a1a27688e6e Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Wed, 10 Sep 2025 10:35:05 +0000
Subject: [PATCH 1/3] [MLIR][XeGPU] Add lowering from vector.gather/scatter to
xegpu.load/store
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../VectorToXeGPU/VectorToXeGPU.cpp | 143 +++++++++++++++-
.../VectorToXeGPU/gather-to-xegpu.mlir | 160 ++++++++++++++++++
.../VectorToXeGPU/scatter-to-xegpu.mlir | 125 ++++++++++++++
3 files changed, 420 insertions(+), 8 deletions(-)
create mode 100644 mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
create mode 100644 mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 852c322cc6467..2ad4f7bf2d074 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,6 +97,21 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return success();
}
+static LogicalResult gatherScatterPreconditions(PatternRewriter &rewriter,
+ Operation *op, Type baseType) {
+ auto srcTy = dyn_cast<MemRefType>(baseType);
+ if (!srcTy)
+ return rewriter.notifyMatchFailure(op, "Expects memref source");
+
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "Buffer must be contiguous in the innermost dimension");
+
+ return success();
+}
+
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
xegpu::TensorDescType descType, TypedValue<MemRefType> src,
@@ -183,11 +198,15 @@ static void adjustStridesForPermutation(AffineMap permMap,
// Computes memory strides and a memref offset for vector transfer operations,
// handling both static and dynamic memrefs while applying permutation
// transformations for XeGPU lowering.
+template <
+ typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
+ vector::GatherOp, vector::ScatterOp>::value>>
static std::pair<SmallVector<Value>, Value>
-computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
+computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
SmallVector<Value> strides;
Value baseMemref = xferOp.getBase();
- AffineMap permMap = xferOp.getPermutationMap();
MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
Location loc = xferOp.getLoc();
@@ -232,9 +251,15 @@ computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
if (!offsetVal)
offsetVal = meta.getOffset();
}
- // Adjust strides according to the permutation map (e.g., for transpose)
- adjustStridesForPermutation(permMap, strides);
- return {strides, offsetVal};
+
+ if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
+ vector::TransferWriteOp>::value) {
+ AffineMap permMap = xferOp.getPermutationMap();
+ // Adjust strides according to the permutation map (e.g., for transpose)
+ adjustStridesForPermutation(permMap, strides);
+ }
+
+ return strides;
}
// This function compute the vectors of localOffsets for scattered load/stores.
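For context, when the base memref has dynamic sizes or a dynamic offset, the strides (and memref offset) above are recovered via memref.extract_strided_metadata, as the dynamic-source tests below expect. A minimal sketch for a rank-3 dynamic memref, with illustrative value names:

  // Materialize the base buffer, offset, sizes, and strides of %src as SSA values.
  %base, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %src
      : memref<?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index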
@@ -339,8 +364,45 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
return localOffsets;
}
+// Compute the element-wise offsets for vector.gather or vector.scatter ops.
+//
+// This function linearizes the base offsets of the gather/scatter operation
+// and combines them with the per-element indices to produce a final vector of
+// memory offsets.
+template <
+ typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
+static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
+ ArrayRef<Value> strides) {
+ Location loc = gatScatOp.getLoc();
+ SmallVector<Value> offsets = gatScatOp.getOffsets();
+ Value linearOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ for (size_t i = 0; i < offsets.size(); ++i) {
+ Value offsetContrib =
+ arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
+ linearOffset =
+ arith::AddIOp::create(rewriter, loc, linearOffset, offsetContrib);
+ }
+ Value indices = gatScatOp.getIndices();
+ VectorType vecType = cast<VectorType>(indices.getType());
+
+ Value baseVector =
+ vector::BroadcastOp::create(
+ rewriter, loc,
+ VectorType::get(vecType.getShape(), rewriter.getIndexType()),
+ linearOffset)
+ .getResult();
+ return arith::AddIOp::create(rewriter, loc, baseVector, indices).getResult();
+}
+
+template <
+ typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
+ vector::GatherOp, vector::ScatterOp>::value>>
// Convert memref to i64 base pointer
-static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
+static Value memrefToIndexPtr(OpType xferOp,
PatternRewriter &rewriter) {
Location loc = xferOp.getLoc();
auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
@@ -539,6 +601,69 @@ struct TransferWriteLowering
}
};
+struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
+ using OpRewritePattern<vector::GatherOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
+ PatternRewriter &rewriter) const override {
+ if (failed(gatherScatterPreconditions(rewriter, gatherOp,
+ gatherOp.getBase().getType())))
+ return failure();
+
+ Location loc = gatherOp.getLoc();
+ VectorType vectorType = gatherOp.getVectorType();
+
+ SmallVector<Value> strides = computeStrides(gatherOp, rewriter);
+ if (strides.empty())
+ return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");
+
+ Value localOffsets = computeOffsets(rewriter, gatherOp, strides);
+ Value flatMemref = collapseMemrefTo1D(gatherOp, rewriter);
+
+ auto xeGatherOp = xegpu::LoadGatherOp::create(
+ rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
+ /*chunk_size=*/IntegerAttr{},
+ /*l1_hint=*/xegpu::CachePolicyAttr{},
+ /*l2_hint=*/xegpu::CachePolicyAttr{},
+ /*l3_hint=*/xegpu::CachePolicyAttr{});
+
+ auto selectOp =
+ arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
+ xeGatherOp.getResult(), gatherOp.getPassThru());
+ rewriter.replaceOp(gatherOp, selectOp.getResult());
+ return success();
+ }
+};
+
+struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
+ using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
+ PatternRewriter &rewriter) const override {
+ if (failed(gatherScatterPreconditions(rewriter, scatterOp,
+ scatterOp.getBase().getType())))
+ return failure();
+
+ Location loc = scatterOp.getLoc();
+ SmallVector<Value> strides = computeStrides(scatterOp, rewriter);
+ if (strides.empty())
+ return rewriter.notifyMatchFailure(scatterOp,
+ "Failed to compute strides");
+
+ Value localOffsets = computeOffsets(rewriter, scatterOp, strides);
+ Value flatMemref = collapseMemrefTo1D(scatterOp, rewriter);
+
+ xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
+ flatMemref, localOffsets, scatterOp.getMask(),
+ /*chunk_size=*/IntegerAttr{},
+ /*l1_hint=*/xegpu::CachePolicyAttr{},
+ /*l2_hint=*/xegpu::CachePolicyAttr{},
+ /*l3_hint=*/xegpu::CachePolicyAttr{});
+ rewriter.eraseOp(scatterOp);
+ return success();
+ }
+};
+
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
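Putting the pieces together, a 1-D gather from a static memref is rewritten in this revision into a flattened memref, a masked xegpu.load, and a select against the pass-through vector, roughly matching the checks in gather-to-xegpu.mlir below (value names are illustrative):

  %flat = memref.collapse_shape %source [[0, 1, 2]]
      : memref<8x16x32xf32> into memref<4096xf32>
  %vec = xegpu.load %flat[%lin], %mask
      : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
  %res = arith.select %mask, %vec, %pass_thru : vector<8xi1>, vector<8xf32>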
@@ -654,6 +779,8 @@ struct ConvertVectorToXeGPUPass
void mlir::populateVectorToXeGPUConversionPatterns(
RewritePatternSet &patterns) {
- patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
- StoreLowering, ContractionLowering>(patterns.getContext());
+ patterns
+ .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
+ ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
+ patterns.getContext());
}
diff --git a/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
new file mode 100644
index 0000000000000..6b249ba37becd
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
@@ -0,0 +1,160 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+gpu.module @xevm_module {
+gpu.func @load_1D_vector(%source: memref<8x16x32xf32>,
+ %off1: index, %off2: index, %off3: index,
+ %indices: vector<8xindex>, %mask: vector<8xi1>,
+ %pass_thru: vector<8xf32>) -> vector<8xf32> {
+ %0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
+ %pass_thru : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+ gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL: @load_1D_vector(
+// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>
+// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>
+// CHECK-SAME: %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
+// CHECK: gpu.return %[[RES]] : vector<8xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_memref(%source: memref<8x32xf32>,
+ %off1: index, %off2: index,
+ %indices: vector<8xindex>, %mask: vector<8xi1>,
+ %pass_thru: vector<8xf32>) -> vector<8xf32> {
+ %0 = vector.gather %source[%off1, %off2][%indices], %mask,
+ %pass_thru : memref<8x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+ gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL: @load_2D_memref(
+// CHECK-SAME: %[[SRC:.+]]: memref<8x32xf32>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>
+// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>
+// CHECK-SAME: %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK-COUNT1: arith.muli {{.*}} : index
+// CHECK-COUNT1: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<8x32xf32> into memref<256xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<256xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
+// CHECK: gpu.return %[[RES]] : vector<8xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
+ %off0: index, %off1: index, %off2: index,
+ %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
+ %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+ %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask,
+ %pass_thru : memref<8x16x32xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+ gpu.return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL: @load_2D_vector(
+// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8x16xindex>
+// CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>
+// CHECK-SAME: %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
+// CHECK: gpu.return %[[RES]] : vector<8x16xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
+ %off0: index, %off1: index, %off2: index,
+ %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
+ %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+ %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask,
+ %pass_thru : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+ gpu.return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL: @load_dynamic_source(
+// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8x16xindex>
+// CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>
+// CHECK-SAME: %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
+// CHECK: memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
+// CHECK: gpu.return %[[RES]] : vector<8x16xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
+ %off0: index, %off1: index, %off2: index,
+ %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
+ %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+ %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask,
+ %pass_thru : memref<?x8x16xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+ gpu.return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL: @load_dynamic_source2(
+// CHECK-SAME: %[[SRC:.+]]: memref<?x8x16xf32>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8x16xindex>
+// CHECK-SAME: %[[MASK:.+]]: vector<8x16xi1>
+// CHECK-SAME: %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
+// CHECK-NOT: memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
+// CHECK: gpu.return %[[RES]] : vector<8x16xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @no_load_tensor(%source: tensor<32x64xf32>,
+ %off: index, %indices: vector<8x16xindex>,
+ %mask: vector<8x16xi1>, %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
+ %0 = vector.gather %source[%off, %off][%indices], %mask,
+ %pass_thru : tensor<32x64xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
+ gpu.return %0 : vector<8x16xf32>
+}
+// CHECK-LABEL: @no_load_tensor(
+// CHECK: vector.gather
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @no_load_non_unit_inner_stride(
+ %source: memref<32xf32, strided<[?], offset: ?>>,
+ %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>,
+ %pass_thru: vector<8xf32>) -> vector<8xf32> {
+ %0 = vector.gather %source[%off][%indices], %mask, %pass_thru
+ : memref<32xf32, strided<[?], offset: ?>>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+ gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL: @no_load_non_unit_inner_stride(
+// CHECK: vector.gather
+}
+
diff --git a/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
new file mode 100644
index 0000000000000..4448ed5b6cda9
--- /dev/null
+++ b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
@@ -0,0 +1,125 @@
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+
+gpu.module @xevm_module {
+gpu.func @store_1D_vector(%vec: vector<8xf32>, %source: memref<8x16x32xf32>,
+ %off1: index, %off2: index, %off3: index,
+ %indices: vector<8xindex>, %mask: vector<8xi1>) {
+ vector.scatter %source[%off1, %off2, %off3][%indices], %mask, %vec
+ : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+ gpu.return
+}
+// CHECK-LABEL: @store_1D_vector(
+// CHECK-SAME: %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// CHECK: xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1>
+// CHECK: gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_2D_memref(%vec: vector<8xf32>, %source: memref<8x32xf32>,
+ %off1: index, %off2: index,
+ %indices: vector<8xindex>, %mask: vector<8xi1>) {
+ vector.scatter %source[%off1, %off2][%indices], %mask, %vec
+ : memref<8x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+ gpu.return
+}
+// CHECK-LABEL: @store_2D_memref(
+// CHECK-SAME: %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<8x32xf32>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK-COUNT1: arith.muli {{.*}} : index
+// CHECK-COUNT1: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<8x32xf32> into memref<256xf32>
+// CHECK: xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, memref<256xf32>, vector<8xindex>, vector<8xi1>
+// CHECK: gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_2D_vector(%vec: vector<8x16xf32>, %source: memref<8x16x32xf32>,
+ %off0: index, %off1: index, %off2: index,
+ %indices: vector<8x16xindex>, %mask: vector<8x16xi1>) {
+ vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+ : memref<8x16x32xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+ gpu.return
+}
+// CHECK-LABEL: @store_2D_vector(
+// CHECK-SAME: %[[VAL:.+]]: vector<8x16xf32>, %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>) {
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// CHECK: xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1>
+// CHECK: gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_dynamic_source(%vec: vector<8x16xf32>, %source: memref<?x?x?xf32>,
+ %off0: index, %off1: index, %off2: index,
+ %indices: vector<8x16xindex>, %mask: vector<8x16xi1>) {
+ vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+ : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+ gpu.return
+}
+// CHECK-LABEL: @store_dynamic_source(
+// CHECK-SAME: %[[VAL:.+]]: vector<8x16xf32>, %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>) {
+// CHECK: memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
+// CHECK: xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, memref<?xf32>, vector<8x16xindex>, vector<8x16xi1>
+// CHECK: gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_dynamic_source2(%vec: vector<8x16xf32>, %source: memref<?x8x16xf32>,
+ %off0: index, %off1: index, %off2: index,
+ %indices: vector<8x16xindex>, %mask: vector<8x16xi1>) {
+ vector.scatter %source[%off0, %off1, %off2][%indices], %mask, %vec
+ : memref<?x8x16xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32>
+ gpu.return
+}
+// CHECK-LABEL: @store_dynamic_source2(
+// CHECK-SAME: %[[VAL:.+]]: vector<8x16xf32>, %[[SRC:.+]]: memref<?x8x16xf32>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8x16xindex>, %[[MASK:.+]]: vector<8x16xi1>) {
+// CHECK-NOT: memref.extract_strided_metadata %[[SRC]]
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
+// CHECK: xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, memref<?xf32>, vector<8x16xindex>, vector<8x16xi1>
+// CHECK: gpu.return
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @no_store_non_unit_inner_stride(
+ %vec: vector<8xf32>, %source: memref<32xf32, strided<[?], offset: ?>>,
+ %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>) {
+ vector.scatter %source[%off][%indices], %mask, %vec
+ : memref<32xf32, strided<[?], offset: ?>>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+ gpu.return
+}
+// CHECK-LABEL: @no_store_non_unit_inner_stride(
+// CHECK: vector.scatter
+}
From 4d2c2844befcf52ac378a159cca98100731f1d63 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 11 Sep 2025 13:29:17 +0000
Subject: [PATCH 2/3] Add alignment handling
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../VectorToXeGPU/VectorToXeGPU.cpp | 12 +++++++++
.../VectorToXeGPU/gather-to-xegpu.mlir | 27 +++++++++++++++++++
.../VectorToXeGPU/scatter-to-xegpu.mlir | 23 ++++++++++++++++
3 files changed, 62 insertions(+)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 2ad4f7bf2d074..89d383ab4b7e6 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -620,6 +620,12 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
Value localOffsets = computeOffsets(rewriter, gatherOp, strides);
Value flatMemref = collapseMemrefTo1D(gatherOp, rewriter);
+ if (auto alignment = gatherOp.getAlignment()) {
+ flatMemref = memref::AssumeAlignmentOp::create(rewriter, loc, flatMemref,
+ alignment.value())
+ .getResult();
+ }
+
auto xeGatherOp = xegpu::LoadGatherOp::create(
rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
/*chunk_size=*/IntegerAttr{},
@@ -653,6 +659,12 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
Value localOffsets = computeOffsets(rewriter, scatterOp, strides);
Value flatMemref = collapseMemrefTo1D(scatterOp, rewriter);
+ if (auto alignment = scatterOp.getAlignment()) {
+ flatMemref = memref::AssumeAlignmentOp::create(rewriter, loc, flatMemref,
+ alignment.value())
+ .getResult();
+ }
+
xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
flatMemref, localOffsets, scatterOp.getMask(),
/*chunk_size=*/IntegerAttr{},
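With this change, an alignment attribute on the vector op is carried over as an assumption on the collapsed memref before it feeds xegpu.load/xegpu.store; the new op is a single line, mirroring the updated tests below (value names are illustrative):

  // alignment = 256 on vector.gather/vector.scatter becomes:
  %aligned = memref.assume_alignment %flat, 256 : memref<4096xf32>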
diff --git a/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
index 6b249ba37becd..4a1988c45ad43 100644
--- a/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
@@ -158,3 +158,30 @@ gpu.func @no_load_non_unit_inner_stride(
// CHECK: vector.gather
}
+// -----
+gpu.module @xevm_module {
+gpu.func @load_1D_aligned(%source: memref<8x16x32xf32>,
+ %off1: index, %off2: index, %off3: index,
+ %indices: vector<8xindex>, %mask: vector<8xi1>,
+ %pass_thru: vector<8xf32>) -> vector<8xf32> {
+ %0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
+ %pass_thru {alignment = 256} : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+ gpu.return %0 : vector<8xf32>
+}
+// CHECK-LABEL: @load_1D_aligned(
+// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>
+// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>
+// CHECK-SAME: %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// CHECK: %[[COLLAPSE_ALIGN:.+]] = memref.assume_alignment %[[COLLAPSE]], 256 : memref<4096xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_ALIGN]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
+// CHECK: gpu.return %[[RES]] : vector<8xf32>
+}
+
diff --git a/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
index 4448ed5b6cda9..73c4e735d01aa 100644
--- a/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
@@ -123,3 +123,26 @@ gpu.func @no_store_non_unit_inner_stride(
// CHECK-LABEL: @no_store_non_unit_inner_stride(
// CHECK: vector.scatter
}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_1D_aligned(%vec: vector<8xf32>, %source: memref<8x16x32xf32>,
+ %off1: index, %off2: index, %off3: index,
+ %indices: vector<8xindex>, %mask: vector<8xi1>) {
+ vector.scatter %source[%off1, %off2, %off3][%indices], %mask, %vec {alignment = 256}
+ : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+ gpu.return
+}
+// CHECK-LABEL: @store_1D_aligned(
+// CHECK-SAME: %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<8x16x32xf32>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
+// CHECK-COUNT2: arith.muli {{.*}} : index
+// CHECK-COUNT2: arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// CHECK: %[[COLLAPSE_ALIGN:.+]] = memref.assume_alignment %[[COLLAPSE]], 256 : memref<4096xf32>
+// CHECK: xegpu.store %[[VAL]], %[[COLLAPSE_ALIGN]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1>
+// CHECK: gpu.return
+}
From 62c5c38064ad4077139012892e56fd6220b103b8 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Fri, 12 Sep 2025 11:01:50 +0000
Subject: [PATCH 3/3] Align lowering with new utils behavior
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../VectorToXeGPU/VectorToXeGPU.cpp | 47 ++++-----
.../VectorToXeGPU/gather-to-xegpu.mlir | 96 +++++++++----------
.../VectorToXeGPU/scatter-to-xegpu.mlir | 65 ++++++++-----
3 files changed, 107 insertions(+), 101 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 89d383ab4b7e6..eebaceba488b4 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,6 +97,9 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return success();
}
+// Common preconditions for the lowering of vector.gather and vector.scatter:
+// 1. Source is a memref.
+// 2. The innermost dimension of the memref is contiguous (stride == 1)
static LogicalResult gatherScatterPreconditions(PatternRewriter &rewriter,
Operation *op, Type baseType) {
auto srcTy = dyn_cast<MemRefType>(baseType);
@@ -259,7 +262,7 @@ computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
adjustStridesForPermutation(permMap, strides);
}
- return strides;
+ return {strides, offsetVal};
}
// This function compute the vectors of localOffsets for scattered load/stores.
@@ -374,15 +377,14 @@ template <
typename = std::enable_if_t<llvm::is_one_of<
std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
- ArrayRef<Value> strides) {
+ ArrayRef<Value> strides, Value baseOffset) {
Location loc = gatScatOp.getLoc();
SmallVector<Value> offsets = gatScatOp.getOffsets();
- Value linearOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
for (size_t i = 0; i < offsets.size(); ++i) {
Value offsetContrib =
arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
- linearOffset =
- arith::AddIOp::create(rewriter, loc, linearOffset, offsetContrib);
+ baseOffset =
+ arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
}
Value indices = gatScatOp.getIndices();
VectorType vecType = cast<VectorType>(indices.getType());
@@ -391,7 +393,7 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
vector::BroadcastOp::create(
rewriter, loc,
VectorType::get(vecType.getShape(), rewriter.getIndexType()),
- linearOffset)
+ baseOffset)
.getResult();
return arith::AddIOp::create(rewriter, loc, baseVector, indices).getResult();
}
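The practical effect is that the accumulation now starts from the memref's own base offset (for example, the dynamic offset of a subview recovered by memref.extract_strided_metadata) rather than from zero. A rough sketch for a rank-2 subview with strides [4096, 1], eliding the unit-stride multiply and using illustrative names; how the stride values themselves are materialized is left to computeMemrefMeta:

  // %offset is the subview's base offset from memref.extract_strided_metadata.
  %contrib = arith.muli %off1, %stride0 : index
  %acc = arith.addi %offset, %contrib : index
  %base_off = arith.addi %acc, %off2 : index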
@@ -402,8 +404,7 @@ template <
std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
vector::GatherOp, vector::ScatterOp>::value>>
// Convert memref to i64 base pointer
-static Value memrefToIndexPtr(OpType xferOp,
- PatternRewriter &rewriter) {
+static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
Location loc = xferOp.getLoc();
auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
rewriter, loc, xferOp.getBase())
@@ -613,18 +614,13 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
Location loc = gatherOp.getLoc();
VectorType vectorType = gatherOp.getVectorType();
- SmallVector<Value> strides = computeStrides(gatherOp, rewriter);
- if (strides.empty())
+ auto meta = computeMemrefMeta(gatherOp, rewriter);
+ if (meta.first.empty())
return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");
- Value localOffsets = computeOffsets(rewriter, gatherOp, strides);
- Value flatMemref = collapseMemrefTo1D(gatherOp, rewriter);
-
- if (auto alignment = gatherOp.getAlignment()) {
- flatMemref = memref::AssumeAlignmentOp::create(rewriter, loc, flatMemref,
- alignment.value())
- .getResult();
- }
+ Value localOffsets =
+ computeOffsets(rewriter, gatherOp, meta.first, meta.second);
+ Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
auto xeGatherOp = xegpu::LoadGatherOp::create(
rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
@@ -651,19 +647,14 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
return failure();
Location loc = scatterOp.getLoc();
- SmallVector<Value> strides = computeStrides(scatterOp, rewriter);
- if (strides.empty())
+ auto meta = computeMemrefMeta(scatterOp, rewriter);
+ if (meta.first.empty())
return rewriter.notifyMatchFailure(scatterOp,
"Failed to compute strides");
- Value localOffsets = computeOffsets(rewriter, scatterOp, strides);
- Value flatMemref = collapseMemrefTo1D(scatterOp, rewriter);
-
- if (auto alignment = scatterOp.getAlignment()) {
- flatMemref = memref::AssumeAlignmentOp::create(rewriter, loc, flatMemref,
- alignment.value())
- .getResult();
- }
+ Value localOffsets =
+ computeOffsets(rewriter, scatterOp, meta.first, meta.second);
+ Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
flatMemref, localOffsets, scatterOp.getMask(),
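After this revision the flattened memref is gone: the base is taken as a raw pointer, and the scattered xegpu ops operate on an i64 base plus per-lane index offsets, as the reworked tests below check (value names are illustrative):

  %base_idx = memref.extract_aligned_pointer_as_index %source
      : memref<8x16x32xf32> -> index
  %base_i64 = arith.index_cast %base_idx : index to i64
  %vec = xegpu.load %base_i64[%lin], %mask
      : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>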
diff --git a/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
index 4a1988c45ad43..8eb9a40f5ae53 100644
--- a/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir
@@ -19,8 +19,9 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>,
// CHECK-COUNT2: arith.addi {{.*}} : index
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
// CHECK: gpu.return %[[RES]] : vector<8xf32>
}
@@ -45,8 +46,9 @@ gpu.func @load_2D_memref(%source: memref<8x32xf32>,
// CHECK-COUNT1: arith.addi {{.*}} : index
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<8x32xf32> into memref<256xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<256xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x32xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
// CHECK: gpu.return %[[RES]] : vector<8xf32>
}
@@ -71,8 +73,9 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
// CHECK-COUNT2: arith.addi {{.*}} : index
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
// CHECK: gpu.return %[[RES]] : vector<8x16xf32>
}
@@ -98,8 +101,9 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// CHECK-COUNT2: arith.addi {{.*}} : index
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
// CHECK: gpu.return %[[RES]] : vector<8x16xf32>
}
@@ -125,8 +129,9 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
// CHECK-COUNT2: arith.addi {{.*}} : index
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x8x16xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
// CHECK: gpu.return %[[RES]] : vector<8x16xf32>
}
@@ -146,42 +151,37 @@ gpu.func @no_load_tensor(%source: tensor<32x64xf32>,
// -----
gpu.module @xevm_module {
-gpu.func @no_load_non_unit_inner_stride(
- %source: memref<32xf32, strided<[?], offset: ?>>,
- %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>,
- %pass_thru: vector<8xf32>) -> vector<8xf32> {
- %0 = vector.gather %source[%off][%indices], %mask, %pass_thru
- : memref<32xf32, strided<[?], offset: ?>>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
- gpu.return %0 : vector<8xf32>
-}
-// CHECK-LABEL: @no_load_non_unit_inner_stride(
-// CHECK: vector.gather
+gpu.func @gather_from_subview(%source: memref<4096x4096xf16>,
+ %off1: index, %off2: index,
+ %indices: vector<8xindex>,
+ %mask: vector<8xi1>,
+ %pass_thru: vector<8xf16>) -> vector<8xf16> {
+ %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1]
+ : memref<4096x4096xf16>
+ to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+ %0 = vector.gather %subview[%off1, %off2][%indices], %mask, %pass_thru
+ : memref<256x256xf16, strided<[4096, 1], offset: ?>>,
+ vector<8xindex>, vector<8xi1>, vector<8xf16>
+ into vector<8xf16>
+ gpu.return %0 : vector<8xf16>
+}
+// CHECK-LABEL: @gather_from_subview(
+// CHECK-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
+// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>,
+// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>,
+// CHECK-SAME: %[[PASS:.+]]: vector<8xf16>) -> vector<8xf16> {
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
+// CHECK: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// CHECK: arith.muli {{.*}} : index
+// CHECK: arith.addi %[[OFFSET]]{{.*}} : index
+// CHECK: %[[BASE_OFF:.+]] = arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast %[[BASE_OFF]] : index to vector<8xindex>
+// CHECK: %[[LIN:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[BASE_IDX:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE_IDX]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[BASE_I64]]{{\[}}%[[LIN]]{{\]}}, %[[MASK]]
+// CHECK-SAME: : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
+// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS]] : vector<8xi1>, vector<8xf16>
+// CHECK: gpu.return %[[RES]] : vector<8xf16>
}
-
-// -----
-gpu.module @xevm_module {
-gpu.func @load_1D_aligned(%source: memref<8x16x32xf32>,
- %off1: index, %off2: index, %off3: index,
- %indices: vector<8xindex>, %mask: vector<8xi1>,
- %pass_thru: vector<8xf32>) -> vector<8xf32> {
- %0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
- %pass_thru {alignment = 256} : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
- gpu.return %0 : vector<8xf32>
-}
-// CHECK-LABEL: @load_1D_aligned(
-// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
-// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>
-// CHECK-SAME: %[[MASK:.+]]: vector<8xi1>
-// CHECK-SAME: %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
-// CHECK-COUNT2: arith.muli {{.*}} : index
-// CHECK-COUNT2: arith.addi {{.*}} : index
-// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
-// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// CHECK: %[[COLLAPSE_ALIGN:.+]] = memref.assume_alignment %[[COLLAPSE]], 256 : memref<4096xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_ALIGN]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
-// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
-// CHECK: gpu.return %[[RES]] : vector<8xf32>
-}
-
diff --git a/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
index 73c4e735d01aa..ea6a34a437962 100644
--- a/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/scatter-to-xegpu.mlir
@@ -16,8 +16,9 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>, %source: memref<8x16x32xf32>,
// CHECK-COUNT2: arith.addi {{.*}} : index
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// CHECK: xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
// CHECK: gpu.return
}
@@ -38,8 +39,9 @@ gpu.func @store_2D_memref(%vec: vector<8xf32>, %source: memref<8x32xf32>,
// CHECK-COUNT1: arith.addi {{.*}} : index
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<8x32xf32> into memref<256xf32>
-// CHECK: xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, memref<256xf32>, vector<8xindex>, vector<8xi1>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x32xf32> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
// CHECK: gpu.return
}
@@ -60,8 +62,9 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>, %source: memref<8x16x32xf32>,
// CHECK-COUNT2: arith.addi {{.*}} : index
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// CHECK: xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
// CHECK: gpu.return
}
@@ -83,8 +86,9 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>, %source: memref<?x?x?xf32
// CHECK-COUNT2: arith.addi {{.*}} : index
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
-// CHECK: xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, memref<?xf32>, vector<8x16xindex>, vector<8x16xi1>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?xf32> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
// CHECK: gpu.return
}
@@ -106,8 +110,9 @@ gpu.func @store_dynamic_source2(%vec: vector<8x16xf32>, %source: memref<?x8x16xf
// CHECK-COUNT2: arith.addi {{.*}} : index
// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
-// CHECK: xegpu.store %[[VAL]], %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, memref<?xf32>, vector<8x16xindex>, vector<8x16xi1>
+// CHECK: %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x8x16xf32> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
+// CHECK: xegpu.store %[[VAL]], %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8x16xf32>, i64, vector<8x16xindex>, vector<8x16xi1>
// CHECK: gpu.return
}
@@ -126,23 +131,33 @@ gpu.func @no_store_non_unit_inner_stride(
// -----
gpu.module @xevm_module {
-gpu.func @store_1D_aligned(%vec: vector<8xf32>, %source: memref<8x16x32xf32>,
- %off1: index, %off2: index, %off3: index,
- %indices: vector<8xindex>, %mask: vector<8xi1>) {
- vector.scatter %source[%off1, %off2, %off3][%indices], %mask, %vec {alignment = 256}
- : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
+gpu.func @scatter_into_subview(%vals: vector<8xf16>,
+ %source: memref<4096x4096xf16>,
+ %off1: index, %off2: index,
+ %indices: vector<8xindex>,
+ %mask: vector<8xi1>) {
+ %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1]
+ : memref<4096x4096xf16>
+ to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+ vector.scatter %subview[%off1, %off2][%indices], %mask, %vals
+ : memref<256x256xf16, strided<[4096, 1], offset: ?>>,
+ vector<8xindex>, vector<8xi1>, vector<8xf16>
gpu.return
}
-// CHECK-LABEL: @store_1D_aligned(
-// CHECK-SAME: %[[VAL:.+]]: vector<8xf32>, %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
+// CHECK-LABEL: @scatter_into_subview(
+// CHECK-SAME: %[[VALS:.+]]: vector<8xf16>,
+// CHECK-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// CHECK-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
// CHECK-SAME: %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>) {
-// CHECK-COUNT2: arith.muli {{.*}} : index
-// CHECK-COUNT2: arith.addi {{.*}} : index
-// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
-// CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// CHECK: %[[COLLAPSE_ALIGN:.+]] = memref.assume_alignment %[[COLLAPSE]], 256 : memref<4096xf32>
-// CHECK: xegpu.store %[[VAL]], %[[COLLAPSE_ALIGN]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
+// CHECK: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// CHECK: arith.muli {{.*}} : index
+// CHECK: arith.addi %[[OFFSET]]{{.*}} : index
+// CHECK: %[[BASE_OFF:.+]] = arith.addi {{.*}} : index
+// CHECK: %[[SPLAT:.+]] = vector.broadcast %[[BASE_OFF]] : index to vector<8xindex>
+// CHECK: %[[LIN:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK: %[[BASE_IDX:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// CHECK: %[[BASE_I64:.+]] = arith.index_cast %[[BASE_IDX]] : index to i64
+// CHECK: xegpu.store %[[VALS]], %[[BASE_I64]]{{\[}}%[[LIN]]{{\]}}, %[[MASK]] : vector<8xf16>, i64, vector<8xindex>, vector<8xi1>
// CHECK: gpu.return
}