[Mlir-commits] [mlir] [MLIR][XeGPU] Scattered ops sg-to-wi distribution (PR #154949)

Artem Kroviakov llvmlistbot at llvm.org
Mon Aug 25 11:06:48 PDT 2025


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/154949

>From a081c8db7067d835da4d5d01be6d318f1987745c Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Sat, 23 Aug 2025 10:22:34 +0000
Subject: [PATCH 1/2] [MLIR][XeGPU] Scattered ops sg-to-wi distribution

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 141 +++++++++++++++++-
 .../Dialect/XeGPU/subgroup-distribute.mlir    |  42 ++++++
 2 files changed, 179 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 8e47968609d32..647b11b3706e1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -807,6 +807,136 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<gpu::YieldOp>(
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+    if (!storeScatterOp)
+      return failure();
+    else if (!storeScatterOp.getOffsets())
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Store op must have offsets argument");
+    else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
+                 .getRank() != 1)
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Expected 1D offsets vector");
+
+    VectorType storeVecTy =
+        cast<VectorType>(storeScatterOp.getValue().getType());
+    assert(storeVecTy.getRank() <= 2 &&
+           "Expected at most 2D result at SG level");
+    VectorType distStoreVecTy;
+    if (storeVecTy.getRank() == 2)
+      distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
+    else // rank 1
+      distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands =
+        llvm::to_vector_of<Value>(storeScatterOp->getOperands());
+    SmallVector<Type> operandTypes =
+        llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
+    operandTypes[0] = distStoreVecTy;
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    Value offsetsVec = newStoreScatterOpOperands[2];
+    Value maskVec = newStoreScatterOpOperands[3];
+
+    auto loc = newWarpOp.getLoc();
+    Value laneId = warpOp.getLaneid();
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value laneOffset =
+        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
+    laneOffset = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
+    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
+    laneMask = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
+    newStoreScatterOpOperands[2] = laneOffset;
+    newStoreScatterOpOperands[3] = laneMask;
+
+    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
+        rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
+        storeScatterOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
+    rewriter.eraseOp(storeScatterOp);
+    return success();
+  }
+};
+
+struct LoadDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand = getWarpResult(warpOp, [&](Operation *op) {
+      if (!isa<xegpu::LoadGatherOp>(op))
+        return false;
+      auto yield = cast<gpu::YieldOp>(
+          warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+      return yield->getPrevNode() == op;
+    });
+    if (!yieldOperand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a xegpu::LoadGatherOp op");
+
+    auto loadGatherOp =
+        yieldOperand->get().getDefiningOp<xegpu::LoadGatherOp>();
+    if (!loadGatherOp.getOffsets())
+      return rewriter.notifyMatchFailure(loadGatherOp,
+                                         "Load op must have offsets argument");
+    else if (cast<VectorType>(loadGatherOp.getOffsets().getType()).getRank() !=
+             1)
+      return rewriter.notifyMatchFailure(loadGatherOp,
+                                         "Expected 1D offsets vector");
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands =
+        llvm::to_vector_of<Value>(loadGatherOp->getOperands());
+    SmallVector<Type> operandTypes =
+        llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+
+    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    const unsigned operandIdx = yieldOperand->getOperandNumber();
+    VectorType loadVecTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
+
+    Value offsetsVec = newLoadGatherOperands[1];
+    Value maskVec = newLoadGatherOperands[2];
+    auto loc = newWarpOp.getLoc();
+    Value laneId = warpOp.getLaneid();
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value laneOffset =
+        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
+    laneOffset = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
+    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
+    laneMask = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
+    newLoadGatherOperands[1] = laneOffset;
+    newLoadGatherOperands[2] = laneMask;
+
+    xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
+        loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -819,10 +949,11 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
-      patterns.getContext());
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
+           GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
+          patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
@@ -837,6 +968,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
         continue;
 
+      if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
+        continue;
       xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
       if (!layout) {
         op->emitError("Could not find layout attribute for operand ")
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 54ef56e013abb..b319162dc3f25 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -319,3 +319,45 @@ gpu.module @test {
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
+// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+  gpu.func @scatter_ops_chunksize(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+    %1 = arith.constant dense<1>: vector<16xi1>
+    %3 = xegpu.load %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+    xegpu.store %3, %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
+// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
+// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+gpu.module @test {
+  gpu.func @scatter_ops(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+    %1 = arith.constant dense<1>: vector<16xi1>
+    %3 = xegpu.load %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+    xegpu.store %3, %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+    gpu.return
+  }
+}

>From 20c04fc537856721a5fae47b42b915032dbede17 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Mon, 25 Aug 2025 18:04:16 +0000
Subject: [PATCH 2/2] Assume distributable offset and mask producers

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 54 ++++++-------------
 .../Dialect/XeGPU/subgroup-distribute.mlir    | 30 +++++------
 2 files changed, 31 insertions(+), 53 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 647b11b3706e1..63b5b6559dfdd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -817,14 +817,14 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
     auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
     if (!storeScatterOp)
       return failure();
-    else if (!storeScatterOp.getOffsets())
+    if (!storeScatterOp.getOffsets())
       return rewriter.notifyMatchFailure(storeScatterOp,
                                          "Store op must have offsets argument");
-    else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
-                 .getRank() != 1)
+    VectorType offsetsTy =
+        cast<VectorType>(storeScatterOp.getOffsets().getType());
+    if (offsetsTy.getRank() != 1)
       return rewriter.notifyMatchFailure(storeScatterOp,
                                          "Expected 1D offsets vector");
-
     VectorType storeVecTy =
         cast<VectorType>(storeScatterOp.getValue().getType());
     assert(storeVecTy.getRank() <= 2 &&
@@ -836,33 +836,22 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
       distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
 
     SmallVector<size_t> newRetIndices;
-    SmallVector<Value> operands =
-        llvm::to_vector_of<Value>(storeScatterOp->getOperands());
+    SmallVector<Value> operands = storeScatterOp->getOperands();
     SmallVector<Type> operandTypes =
         llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
     operandTypes[0] = distStoreVecTy;
+    // Assume offset and mask producers will be distributed as well.
+    operandTypes[2] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    operandTypes[3] = VectorType::get(
+        {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypes, newRetIndices);
     SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
-    Value offsetsVec = newStoreScatterOpOperands[2];
-    Value maskVec = newStoreScatterOpOperands[3];
-
     auto loc = newWarpOp.getLoc();
-    Value laneId = warpOp.getLaneid();
     rewriter.setInsertionPointAfter(newWarpOp);
-    Value laneOffset =
-        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
-    laneOffset = vector::BroadcastOp::create(
-        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
-    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
-    laneMask = vector::BroadcastOp::create(
-        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
-    newStoreScatterOpOperands[2] = laneOffset;
-    newStoreScatterOpOperands[3] = laneMask;
-
     xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
         rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
         storeScatterOp->getAttrs());
@@ -892,16 +881,20 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     if (!loadGatherOp.getOffsets())
       return rewriter.notifyMatchFailure(loadGatherOp,
                                          "Load op must have offsets argument");
-    else if (cast<VectorType>(loadGatherOp.getOffsets().getType()).getRank() !=
-             1)
+    VectorType offsetsTy =
+        cast<VectorType>(loadGatherOp.getOffsets().getType());
+    if (offsetsTy.getRank() != 1)
       return rewriter.notifyMatchFailure(loadGatherOp,
                                          "Expected 1D offsets vector");
 
     SmallVector<size_t> newRetIndices;
-    SmallVector<Value> operands =
-        llvm::to_vector_of<Value>(loadGatherOp->getOperands());
+    SmallVector<Value> operands = loadGatherOp->getOperands();
     SmallVector<Type> operandTypes =
         llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
+    // Assume offset and mask producers will be distributed as well.
+    operandTypes[1] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    operandTypes[2] = VectorType::get(
+        {1}, getElementTypeOrSelf(loadGatherOp.getMask().getType()));
 
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypes, newRetIndices);
@@ -914,21 +907,8 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
     assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
 
-    Value offsetsVec = newLoadGatherOperands[1];
-    Value maskVec = newLoadGatherOperands[2];
     auto loc = newWarpOp.getLoc();
-    Value laneId = warpOp.getLaneid();
     rewriter.setInsertionPointAfter(newWarpOp);
-    Value laneOffset =
-        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
-    laneOffset = vector::BroadcastOp::create(
-        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
-    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
-    laneMask = vector::BroadcastOp::create(
-        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
-    newLoadGatherOperands[1] = laneOffset;
-    newLoadGatherOperands[2] = laneMask;
-
     xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
         loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
     Value distributedVal = newWarpOp.getResult(operandIdx);
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index b319162dc3f25..1c4684681b62b 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -323,19 +323,18 @@ gpu.module @test {
 // -----
 // CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
 // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
-// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
-// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
-// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
-// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
 // CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
 // CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 gpu.module @test {
-  gpu.func @scatter_ops_chunksize(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+  gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) {
     %1 = arith.constant dense<1>: vector<16xi1>
-    %3 = xegpu.load %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+    %offset = arith.constant dense<12> : vector<16xindex>
+    %3 = xegpu.load %src[%offset], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
         : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-    xegpu.store %3, %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+    xegpu.store %3, %src[%offset], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
         : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     gpu.return
   }
@@ -344,19 +343,18 @@ gpu.module @test {
 // -----
 // CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
 // CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
-// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
-// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
-// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
-// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
+// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
 // CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
-// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
+// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
 // CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 gpu.module @test {
-  gpu.func @scatter_ops(%src: memref<256xf16>, %offset1: vector<16xindex>) {
+  gpu.func @scatter_ops(%src: memref<256xf16>) {
     %1 = arith.constant dense<1>: vector<16xi1>
-    %3 = xegpu.load %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+    %offset = arith.constant dense<12> : vector<16xindex>
+    %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
         : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
-    xegpu.store %3, %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+    xegpu.store %3, %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
         : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     gpu.return
   }



More information about the Mlir-commits mailing list