[Mlir-commits] [mlir] [mlir][xegpu] Add SIMT distribution patterns for UpdateNdOffset and PrefetchNd ops. (PR #138033)

Charitha Saumya llvmlistbot at llvm.org
Mon May 5 13:23:24 PDT 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/138033

>From d06477ef310adc2d6e9cab0df104f63d1641c1e8 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 30 Apr 2025 21:33:37 +0000
Subject: [PATCH 1/7] move work from old branch

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |   2 +-
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 204 +++++++++++++++++-
 .../Dialect/XeGPU/subgroup-distribution.mlir  | 115 ++++++++++
 3 files changed, 319 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 5fa18754305ca..a892f701f724e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -409,7 +409,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
 }
 
 def XeGPU_UpdateNdOffsetOp : XeGPU_Op<"update_nd_offset",
-                [AllTypesMatch<["TensorDesc", "result"]>]> {
+                [Pure, AllTypesMatch<["TensorDesc", "result"]>]> {
   let summary = "It updates the offsets for the TensorDesc.";
   let description = [{The op updates the offset of the given TensorDesc.
     The offsets are relative offset to the current position in the number
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 019032f7743bf..4f8fa7432b7d5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -301,6 +301,10 @@ class LayoutInfoPropagation
                              ArrayRef<LayoutInfoLattice *> operands,
                              ArrayRef<const LayoutInfoLattice *> results);
 
+  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
+
   void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                    ArrayRef<LayoutInfoLattice *> operands,
                                    ArrayRef<const LayoutInfoLattice *> results);
@@ -352,6 +356,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
         visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
       })
+      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
+        visitPrefetchNdOp(prefetchNdOp, operands, results);
+      })
       // No need to propagate the layout to operands in CreateNdDescOp because
       // they are scalars (offsets, sizes, etc.).
       .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
@@ -381,6 +388,18 @@ LogicalResult LayoutInfoPropagation::visitOperation(
   return success();
 }
 
+void LayoutInfoPropagation::visitPrefetchNdOp(
+    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // Here we assign the default layout to the tensor descriptor operand of
+  // prefetch.
+  auto tdescTy = prefetch.getTensorDescType();
+  auto prefetchLayout = getDefaultLayoutInfo(
+      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
+  // Propagate the layout to the source tensor descriptor.
+  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
+}
+
 void LayoutInfoPropagation::visitVectorMultiReductionOp(
     vector::MultiDimReductionOp reduction,
     ArrayRef<LayoutInfoLattice *> operands,
@@ -1412,6 +1431,174 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Sink an update_nd_offset op feeding into yield op of an enclosing
+/// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the
+/// original op that will not be used by the yield op (and should be cleaned
+/// up later). The yield op will bypass the updateOp's arguments. The tensor
+/// descriptor type is not distributed. Appropriate cast ops are inserted if
+/// the distributed types do not match expected xegpu SIMT types.
+/// Example:
+/// ```
+///   #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32, #lo0>) {
+///     ...
+///     %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
+///     !xegpu.tensor_desc<4x8xf32, #lo0>
+///     gpu.yield %update
+///   }
+///   ...
+/// ```
+/// To
+/// ```
+///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
+///   !xegpu.tensor_desc<4x8xf32, #lo0>) {
+///     ...
+///     %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
+///     !xegpu.tensor_desc<4x8xf32, #lo0>
+///     gpu.yield %dead, %arg0, %c32, %c16
+///   }
+///   %0 = xegpu.unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
+///        #lo0> -> !xegpu.tensor_desc<4x8xf32>
+///   %1 = xegpu.update_nd_offset %0, [%c32, %c16]:
+///     !xegpu.tensor_desc<4x8xf32>
+///   ...
+/// ```
+struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+    auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    auto newTensorDescTy = dropLayouts(updateOp.getTensorDescType());
+
+    SmallVector<Value, 3> newYieldValues;
+    SmallVector<Type, 3> newYieldTypes;
+    for (auto operand : updateOp->getOperands()) {
+      newYieldValues.push_back(operand);
+      if (isa<xegpu::TensorDescType>(operand.getType())) {
+        newYieldTypes.push_back(newTensorDescTy);
+      } else {
+        newYieldTypes.push_back(operand.getType());
+      }
+    }
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Value> newUpdateOperands;
+    for (auto i : newRetIndices) {
+      if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
+        newUpdateOperands.push_back(resolveDistributedTy(
+            newWarpOp.getResult(i), newTensorDescTy, rewriter));
+      } else {
+        newUpdateOperands.push_back(newWarpOp.getResult(i));
+      }
+    }
+    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+        newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
+        removeTemporaryLayoutAttributes(updateOp->getAttrs()));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
+    return success();
+  }
+};
+
+struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<gpu::YieldOp>(
+        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
+    if (!prefetchOp)
+      return failure();
+    auto layout = prefetchOp.getTensorDescType().getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          prefetchOp, "the source tensor descriptor lacks layout attribute");
+
+    SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
+    SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+
+    auto newTensorDescTy = dropLayouts(prefetchOp.getTensorDescType());
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
+        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
+    rewriter.create<xegpu::PrefetchNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
+        removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
+    rewriter.eraseOp(prefetchOp);
+    return success();
+  }
+};
+
+/// Generic pattern for sinking a GPU index operations feeding into yield op
+/// of an enclosing `gpu.warp_execute_on_lane_0` region. The original index op
+/// becomes dead and an equivalent copy of the index op is created outside the
+/// warp op.
+/// Example:
+/// ```
+///   %r = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
+///     ...
+///     %index = gpu.block_id x : index
+///     gpu.yield %index
+///   }
+///   ...
+/// ```
+/// To
+/// ```
+///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
+///     ...
+///     %dead = gpu.block_id x : index
+///     gpu.yield %dead
+///   }
+///   %0 = gpu.block_id x : index
+///   ...
+/// ```
+template <typename IndexOp>
+struct GpuIndexOpDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    auto operand = getWarpResult(subgroupOp, llvm::IsaPred<IndexOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(subgroupOp,
+                                         "warp result is not a gpu index op");
+    auto indexOp = operand->template get().template getDefiningOp<IndexOp>();
+    unsigned operandIdx = operand->template getOperandNumber();
+    SmallVector<Value, 3> newYieldValues;
+    SmallVector<Type, 3> newYieldTypes;
+    for (auto operand : indexOp->template getOperands()) {
+      newYieldValues.push_back(operand);
+      newYieldTypes.push_back(operand.getType());
+    }
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Value> newIndexOperands;
+    for (auto i : newRetIndices) {
+      newIndexOperands.push_back(newWarpOp.getResult(i));
+    }
+    auto newIndexOp = rewriter.create<IndexOp>(
+        newWarpOp.getLoc(), newIndexOperands,
+        removeTemporaryLayoutAttributes(indexOp->template getAttrs()));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newIndexOp);
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -1430,7 +1617,22 @@ struct XeGPUSubgroupDistributePass final
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
   patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution>(patterns.getContext());
+               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
+               UpdateNdOffsetDistribution>(patterns.getContext());
+  // TODO: Is this the right place to add these patterns?
+  patterns.add<GpuIndexOpDistribution<gpu::BlockIdOp>,
+               GpuIndexOpDistribution<gpu::BlockDimOp>,
+               GpuIndexOpDistribution<gpu::SubgroupIdOp>,
+               GpuIndexOpDistribution<gpu::SubgroupSizeOp>,
+               GpuIndexOpDistribution<gpu::NumSubgroupsOp>,
+               GpuIndexOpDistribution<gpu::ClusterDimOp>,
+               GpuIndexOpDistribution<gpu::ClusterDimBlocksOp>,
+               GpuIndexOpDistribution<gpu::ClusterIdOp>,
+               GpuIndexOpDistribution<gpu::ClusterBlockIdOp>,
+               GpuIndexOpDistribution<gpu::GridDimOp>,
+               GpuIndexOpDistribution<gpu::ThreadIdOp>,
+               GpuIndexOpDistribution<gpu::LaneIdOp>,
+               GpuIndexOpDistribution<gpu::GlobalIdOp>>(patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
index f8f2cd55c28d0..41f035f9b1fac 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
@@ -160,3 +160,118 @@ gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
   gpu.return
 }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @test_update_nd_offset_1d(
+// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
+// CHECK: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%c32] : !xegpu.tensor_desc<16xf32>
+// CHECK: xegpu.store_nd %[[CST]], %[[T1]]  : vector<1xf32>, !xegpu.tensor_desc<16xf32>
+gpu.module @test {
+gpu.func @test_update_nd_offset_1d(%arg0: memref<256xf32>){
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %1 = arith.constant dense<1.000000e+00> : vector<16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
+  %2 = xegpu.update_nd_offset %0, [%c32] : !xegpu.tensor_desc<16xf32>
+  xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @test_update_nd_offset_2d
+// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
+// CHECK: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
+// CHECK: xegpu.store_nd %[[CST]], %[[T1]]  : vector<16xf32>, !xegpu.tensor_desc<16x16xf32>
+gpu.module @test {
+gpu.func @test_update_nd_offset_2d(%arg0: memref<256x256xf32>){
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %1 = arith.constant dense<1.000000e+00> : vector<16x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
+  %2 = xegpu.update_nd_offset %0, [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
+  xegpu.store_nd %1, %2 : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @test_prefetch_2d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf16>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16>
+gpu.module @test {
+gpu.func @test_prefetch_2d(%arg0: memref<256x256xf16>){
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16x16xf16>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @test_prefetch_1d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
+// CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16>
+gpu.module @test {
+gpu.func @test_prefetch_1d(%arg0: memref<256xf16>){
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
+  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16xf16>
+  gpu.return
+}
+}
+
+
+// -----
+// CHECK-LABEL: gpu.func @test_gemm_loop
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
+// CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
+// CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
+// CHECK: %[[BLOCK_ID_X:.*]] = gpu.block_id x
+// CHECK: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+// CHECK: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
+// CHECK: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[K]], %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+// CHECK: %[[T11:.*]] = xegpu.load_nd %[[T10]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
+// CHECK: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %[[K]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+// CHECK: %[[T13:.*]] = xegpu.load_nd %[[T12]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
+// CHECK: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
+// CHECK: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: scf.yield %[[T16]] : vector<8x1xf32>
+// CHECK: }
+// CHECK: %[[T8:.*]] = xegpu.create_nd_tdesc %[[ARG2]]{{.*}} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: xegpu.store_nd %[[T9]], %[[T8]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+gpu.module @test {
+gpu.func @test_gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c8 = arith.constant 8 : index
+  %c1024 = arith.constant 1024 : index
+  %0 = gpu.block_id x
+  %1 = gpu.block_id y
+  %2 = arith.muli %0, %c8 : index
+  %3 = arith.muli %1, %c16 : index
+  %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+  %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  %6 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) {
+    %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+    %8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+    %9 = xegpu.load_nd %7 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
+    %10 = xegpu.load_nd %8 : !xegpu.tensor_desc<16x16xbf16> -> vector<16x16xbf16>
+    %11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+    scf.yield %11 : vector<8x16xf32>
+  }
+  xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  gpu.return
+}
+}

>From d5d2713d13701db48d05e0a006c16fbe8a0fc2b9 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 30 Apr 2025 22:17:10 +0000
Subject: [PATCH 2/7] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 20 ++++++++++---------
 .../Dialect/XeGPU/subgroup-distribution.mlir  | 20 +++++++++----------
 2 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 4f8fa7432b7d5..a6581a504d1e7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1475,11 +1475,12 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
           subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
     auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
     unsigned operandIdx = operand->getOperandNumber();
-    auto newTensorDescTy = dropLayouts(updateOp.getTensorDescType());
+    xegpu::TensorDescType newTensorDescTy =
+        dropLayouts(updateOp.getTensorDescType());
 
     SmallVector<Value, 3> newYieldValues;
     SmallVector<Type, 3> newYieldTypes;
-    for (auto operand : updateOp->getOperands()) {
+    for (Value operand : updateOp->getOperands()) {
       newYieldValues.push_back(operand);
       if (isa<xegpu::TensorDescType>(operand.getType())) {
         newYieldTypes.push_back(newTensorDescTy);
@@ -1492,7 +1493,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
         rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newUpdateOperands;
-    for (auto i : newRetIndices) {
+    for (size_t i : newRetIndices) {
       if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
         newUpdateOperands.push_back(resolveDistributedTy(
             newWarpOp.getResult(i), newTensorDescTy, rewriter));
@@ -1519,7 +1520,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
     if (!prefetchOp)
       return failure();
-    auto layout = prefetchOp.getTensorDescType().getLayoutAttr();
+    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
     if (!layout)
       return rewriter.notifyMatchFailure(
           prefetchOp, "the source tensor descriptor lacks layout attribute");
@@ -1530,7 +1531,8 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
 
-    auto newTensorDescTy = dropLayouts(prefetchOp.getTensorDescType());
+    xegpu::TensorDescType newTensorDescTy =
+        dropLayouts(prefetchOp.getTensorDescType());
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
         newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
@@ -1570,12 +1572,12 @@ struct GpuIndexOpDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                 PatternRewriter &rewriter) const override {
-    auto operand = getWarpResult(subgroupOp, llvm::IsaPred<IndexOp>);
+    OpOperand *operand = getWarpResult(subgroupOp, llvm::IsaPred<IndexOp>);
     if (!operand)
       return rewriter.notifyMatchFailure(subgroupOp,
                                          "warp result is not a gpu index op");
-    auto indexOp = operand->template get().template getDefiningOp<IndexOp>();
-    unsigned operandIdx = operand->template getOperandNumber();
+    auto indexOp = operand->get().getDefiningOp<IndexOp>();
+    unsigned operandIdx = operand->getOperandNumber();
     SmallVector<Value, 3> newYieldValues;
     SmallVector<Type, 3> newYieldTypes;
     for (auto operand : indexOp->template getOperands()) {
@@ -1587,7 +1589,7 @@ struct GpuIndexOpDistribution final : public gpu::WarpDistributionPattern {
         rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newIndexOperands;
-    for (auto i : newRetIndices) {
+    for (size_t i : newRetIndices) {
       newIndexOperands.push_back(newWarpOp.getResult(i));
     }
     auto newIndexOp = rewriter.create<IndexOp>(
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
index 41f035f9b1fac..5d0665cb6e155 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
@@ -162,14 +162,14 @@ gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
 }
 
 // -----
-// CHECK-LABEL: gpu.func @test_update_nd_offset_1d(
+// CHECK-LABEL: gpu.func @update_nd_offset_1d(
 // CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>) {
 // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
 // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
 // CHECK: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%c32] : !xegpu.tensor_desc<16xf32>
 // CHECK: xegpu.store_nd %[[CST]], %[[T1]]  : vector<1xf32>, !xegpu.tensor_desc<16xf32>
 gpu.module @test {
-gpu.func @test_update_nd_offset_1d(%arg0: memref<256xf32>){
+gpu.func @update_nd_offset_1d(%arg0: memref<256xf32>){
   %c0 = arith.constant 0 : index
   %c32 = arith.constant 32 : index
   %1 = arith.constant dense<1.000000e+00> : vector<16xf32>
@@ -181,14 +181,14 @@ gpu.func @test_update_nd_offset_1d(%arg0: memref<256xf32>){
 }
 
 // -----
-// CHECK-LABEL: gpu.func @test_update_nd_offset_2d
+// CHECK-LABEL: gpu.func @update_nd_offset_2d
 // CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf32>) {
 // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
 // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
 // CHECK: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
 // CHECK: xegpu.store_nd %[[CST]], %[[T1]]  : vector<16xf32>, !xegpu.tensor_desc<16x16xf32>
 gpu.module @test {
-gpu.func @test_update_nd_offset_2d(%arg0: memref<256x256xf32>){
+gpu.func @update_nd_offset_2d(%arg0: memref<256x256xf32>){
   %c0 = arith.constant 0 : index
   %c32 = arith.constant 32 : index
   %1 = arith.constant dense<1.000000e+00> : vector<16x16xf32>
@@ -200,12 +200,12 @@ gpu.func @test_update_nd_offset_2d(%arg0: memref<256x256xf32>){
 }
 
 // -----
-// CHECK-LABEL: gpu.func @test_prefetch_2d
+// CHECK-LABEL: gpu.func @prefetch_2d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf16>) {
 // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
 // CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16>
 gpu.module @test {
-gpu.func @test_prefetch_2d(%arg0: memref<256x256xf16>){
+gpu.func @prefetch_2d(%arg0: memref<256x256xf16>){
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
   xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16x16xf16>
@@ -214,12 +214,12 @@ gpu.func @test_prefetch_2d(%arg0: memref<256x256xf16>){
 }
 
 // -----
-// CHECK-LABEL: gpu.func @test_prefetch_1d
+// CHECK-LABEL: gpu.func @prefetch_1d
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
 // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
 // CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16>
 gpu.module @test {
-gpu.func @test_prefetch_1d(%arg0: memref<256xf16>){
+gpu.func @prefetch_1d(%arg0: memref<256xf16>){
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
   xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16xf16>
@@ -229,7 +229,7 @@ gpu.func @test_prefetch_1d(%arg0: memref<256xf16>){
 
 
 // -----
-// CHECK-LABEL: gpu.func @test_gemm_loop
+// CHECK-LABEL: gpu.func @gemm_loop
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
 // CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
 // CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
@@ -252,7 +252,7 @@ gpu.func @test_prefetch_1d(%arg0: memref<256xf16>){
 // CHECK: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
 // CHECK: xegpu.store_nd %[[T9]], %[[T8]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
 gpu.module @test {
-gpu.func @test_gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
+gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
   %c0 = arith.constant 0 : index
   %c16 = arith.constant 16 : index
   %c8 = arith.constant 8 : index

>From 6aa4aef979f9d52c9f424ce08083d8d43a44e6a0 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 1 May 2025 01:12:03 +0000
Subject: [PATCH 3/7] save work

---
 .../Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp    | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index a6581a504d1e7..e50ef2cede7ea 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1576,11 +1576,11 @@ struct GpuIndexOpDistribution final : public gpu::WarpDistributionPattern {
     if (!operand)
       return rewriter.notifyMatchFailure(subgroupOp,
                                          "warp result is not a gpu index op");
-    auto indexOp = operand->get().getDefiningOp<IndexOp>();
+    Operation *indexOp = operand->get().getDefiningOp<IndexOp>();
     unsigned operandIdx = operand->getOperandNumber();
     SmallVector<Value, 3> newYieldValues;
     SmallVector<Type, 3> newYieldTypes;
-    for (auto operand : indexOp->template getOperands()) {
+    for (Value operand : indexOp->getOperands()) {
       newYieldValues.push_back(operand);
       newYieldTypes.push_back(operand.getType());
     }
@@ -1594,7 +1594,7 @@ struct GpuIndexOpDistribution final : public gpu::WarpDistributionPattern {
     }
     auto newIndexOp = rewriter.create<IndexOp>(
         newWarpOp.getLoc(), newIndexOperands,
-        removeTemporaryLayoutAttributes(indexOp->template getAttrs()));
+        removeTemporaryLayoutAttributes(indexOp->getAttrs()));
     Value distributedVal = newWarpOp.getResult(operandIdx);
     rewriter.replaceAllUsesWith(distributedVal, newIndexOp);
     return success();

>From 1649c52e72d85c558504c80e70840f1ceadb6345 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 5 May 2025 18:29:01 +0000
Subject: [PATCH 4/7] remove index ops

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 71 -------------------
 1 file changed, 71 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index e50ef2cede7ea..1a8c3a79ae515 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1544,63 +1544,6 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
-/// Generic pattern for sinking a GPU index operations feeding into yield op
-/// of an enclosing `gpu.warp_execute_on_lane_0` region. The original index op
-/// becomes dead and an equivalent copy of the index op is created outside the
-/// warp op.
-/// Example:
-/// ```
-///   %r = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
-///     ...
-///     %index = gpu.block_id x : index
-///     gpu.yield %index
-///   }
-///   ...
-/// ```
-/// To
-/// ```
-///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
-///     ...
-///     %dead = gpu.block_id x : index
-///     gpu.yield %dead
-///   }
-///   %0 = gpu.block_id x : index
-///   ...
-/// ```
-template <typename IndexOp>
-struct GpuIndexOpDistribution final : public gpu::WarpDistributionPattern {
-  using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
-                                PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(subgroupOp, llvm::IsaPred<IndexOp>);
-    if (!operand)
-      return rewriter.notifyMatchFailure(subgroupOp,
-                                         "warp result is not a gpu index op");
-    Operation *indexOp = operand->get().getDefiningOp<IndexOp>();
-    unsigned operandIdx = operand->getOperandNumber();
-    SmallVector<Value, 3> newYieldValues;
-    SmallVector<Type, 3> newYieldTypes;
-    for (Value operand : indexOp->getOperands()) {
-      newYieldValues.push_back(operand);
-      newYieldTypes.push_back(operand.getType());
-    }
-    SmallVector<size_t> newRetIndices;
-    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
-    rewriter.setInsertionPointAfter(newWarpOp);
-    SmallVector<Value> newIndexOperands;
-    for (size_t i : newRetIndices) {
-      newIndexOperands.push_back(newWarpOp.getResult(i));
-    }
-    auto newIndexOp = rewriter.create<IndexOp>(
-        newWarpOp.getLoc(), newIndexOperands,
-        removeTemporaryLayoutAttributes(indexOp->getAttrs()));
-    Value distributedVal = newWarpOp.getResult(operandIdx);
-    rewriter.replaceAllUsesWith(distributedVal, newIndexOp);
-    return success();
-  }
-};
-
 } // namespace
 
 namespace {
@@ -1621,20 +1564,6 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
   patterns.add<CreateNdDescDistribution, StoreNdDistribution,
                LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
                UpdateNdOffsetDistribution>(patterns.getContext());
-  // TODO: Is this the right place to add these patterns?
-  patterns.add<GpuIndexOpDistribution<gpu::BlockIdOp>,
-               GpuIndexOpDistribution<gpu::BlockDimOp>,
-               GpuIndexOpDistribution<gpu::SubgroupIdOp>,
-               GpuIndexOpDistribution<gpu::SubgroupSizeOp>,
-               GpuIndexOpDistribution<gpu::NumSubgroupsOp>,
-               GpuIndexOpDistribution<gpu::ClusterDimOp>,
-               GpuIndexOpDistribution<gpu::ClusterDimBlocksOp>,
-               GpuIndexOpDistribution<gpu::ClusterIdOp>,
-               GpuIndexOpDistribution<gpu::ClusterBlockIdOp>,
-               GpuIndexOpDistribution<gpu::GridDimOp>,
-               GpuIndexOpDistribution<gpu::ThreadIdOp>,
-               GpuIndexOpDistribution<gpu::LaneIdOp>,
-               GpuIndexOpDistribution<gpu::GlobalIdOp>>(patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {

>From 8e0c7fd42c842553eb66a224edceae54a6ad5cd8 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 5 May 2025 18:31:12 +0000
Subject: [PATCH 5/7] remove index ops

---
 .../Dialect/XeGPU/subgroup-distribution.mlir  | 49 -------------------
 1 file changed, 49 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
index 5d0665cb6e155..1df0520980766 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
@@ -226,52 +226,3 @@ gpu.func @prefetch_1d(%arg0: memref<256xf16>){
   gpu.return
 }
 }
-
-
-// -----
-// CHECK-LABEL: gpu.func @gemm_loop
-// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
-// CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
-// CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
-// CHECK: %[[BLOCK_ID_X:.*]] = gpu.block_id x
-// CHECK: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
-// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-// CHECK: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
-// CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
-// CHECK: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[K]], %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
-// CHECK: %[[T11:.*]] = xegpu.load_nd %[[T10]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
-// CHECK: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %[[K]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
-// CHECK: %[[T13:.*]] = xegpu.load_nd %[[T12]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
-// CHECK: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
-// CHECK: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
-// CHECK: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
-// CHECK: scf.yield %[[T16]] : vector<8x1xf32>
-// CHECK: }
-// CHECK: %[[T8:.*]] = xegpu.create_nd_tdesc %[[ARG2]]{{.*}} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
-// CHECK: xegpu.store_nd %[[T9]], %[[T8]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-gpu.module @test {
-gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
-  %c0 = arith.constant 0 : index
-  %c16 = arith.constant 16 : index
-  %c8 = arith.constant 8 : index
-  %c1024 = arith.constant 1024 : index
-  %0 = gpu.block_id x
-  %1 = gpu.block_id y
-  %2 = arith.muli %0, %c8 : index
-  %3 = arith.muli %1, %c16 : index
-  %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
-  %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
-  %6 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) {
-    %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
-    %8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
-    %9 = xegpu.load_nd %7 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
-    %10 = xegpu.load_nd %8 : !xegpu.tensor_desc<16x16xbf16> -> vector<16x16xbf16>
-    %11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
-    scf.yield %11 : vector<8x16xf32>
-  }
-  xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  gpu.return
-}
-}

>From a76de600b42aad7825d88dfac71b9b3fdc66ee5b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 5 May 2025 18:52:34 +0000
Subject: [PATCH 6/7] add tests

---
 .../XeGPU/subgroup-map-propagation.mlir       | 59 +++++++++++++++++++
 1 file changed, 59 insertions(+)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index a5468681e68dc..c7c82fc8dbb3c 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -561,3 +561,62 @@ func.func @test_vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.t
   xegpu.store_nd %0, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
   return
 }
+
+// -----
+// CHECK: function: update_nd_offset_1d:
+// CHECK: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
+// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
+// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%{{.*}}] : !xegpu.tensor_desc<16xf32>
+// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
+func.func @update_nd_offset_1d(%arg0: memref<256xf32>){
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %1 = arith.constant dense<1.000000e+00> : vector<16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
+  %2 = xegpu.update_nd_offset %0, [%c32] : !xegpu.tensor_desc<16xf32>
+  xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  return
+}
+
+// -----
+// CHECK: function: update_nd_offset_2d:
+// CHECK: op    : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16x16xf32>
+// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
+// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
+// CHECK-NEXT: op    : %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%{{.*}}] : !xegpu.tensor_desc<16x16xf32>
+// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
+func.func @update_nd_offset_2d(%arg0: memref<256x256xf32>){
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %1 = arith.constant dense<1.000000e+00> : vector<16x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
+  %2 = xegpu.update_nd_offset %0, [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
+  xegpu.store_nd %1, %2 : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32>
+  return
+}
+
+// -----
+// CHECK: function: prefetch_2d:
+// CHECK: layout for result #0: Not assigned.
+// CHECK-NEXT: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
+func.func @prefetch_2d(%arg0: memref<256x256xf16>){
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16x16xf16>
+  return
+}
+
+// -----
+// CHECK: function: prefetch_1d:
+// CHECK: op    : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
+// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
+func.func @prefetch_1d(%arg0: memref<256xf16>){
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
+  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16xf16>
+  return
+}

>From ee555d48c0b7cfb12127f7aba3c810fbf3ed1eac Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 5 May 2025 20:23:02 +0000
Subject: [PATCH 7/7] add tests

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 39 +++++++++++++++++--
 1 file changed, 35 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 1a8c3a79ae515..c7128666da7e8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1192,7 +1192,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
     newStoreOperands.push_back(resolveDistributedTy(
         newWarpOp.getResult(newRetIndices[0]),
         storeNdDistributedValueTyOrFailure.value(), rewriter));
-    // For the tensor descriptor operand, the layout attibute is dropped after
+    // For the tensor descriptor operand, the layout attribute is dropped after
     // distribution. Types needs to be resolved in this case also.
     xegpu::TensorDescType distributedTensorDescTy =
         dropLayouts(storeOp.getTensorDescType());
@@ -1444,7 +1444,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
 ///                   (!xegpu.tensor_desc<4x8xf32, #lo0>) {
 ///     ...
 ///     %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
-///     !xegpu.tensor_desc<4x8xf32, #lo0>
+///       !xegpu.tensor_desc<4x8xf32, #lo0>
 ///     gpu.yield %update
 ///   }
 ///   ...
@@ -1455,7 +1455,7 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
 ///   !xegpu.tensor_desc<4x8xf32, #lo0>) {
 ///     ...
 ///     %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
-///     !xegpu.tensor_desc<4x8xf32, #lo0> gpu.yield %dead, %arg0
+///       !xegpu.tensor_desc<4x8xf32, #lo0>
 ///     gup.yield %dead, %arg0, %c32, %c16
 ///   }
 ///   %0 = xegpu.unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
@@ -1475,6 +1475,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
           subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
     auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
     unsigned operandIdx = operand->getOperandNumber();
+    // The new update op does not have a layout attribute.
     xegpu::TensorDescType newTensorDescTy =
         dropLayouts(updateOp.getTensorDescType());
 
@@ -1494,6 +1495,8 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newUpdateOperands;
     for (size_t i : newRetIndices) {
+      // For the tensor descriptor operand, the layout attribute is dropped
+      // after distribution. Types need to be resolved in this case.
       if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
         newUpdateOperands.push_back(resolveDistributedTy(
             newWarpOp.getResult(i), newTensorDescTy, rewriter));
@@ -1501,6 +1504,7 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
         newUpdateOperands.push_back(newWarpOp.getResult(i));
       }
     }
+    // Create a new update op outside the warp op.
     auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
         newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
         removeTemporaryLayoutAttributes(updateOp->getAttrs()));
@@ -1510,6 +1514,32 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distribute a prefetch_nd op at the end of enclosing
+/// `gpu.warp_execute_on_lane_0`. In case arguments for the prefetch are passed
+/// through the warp op interface they would be propagated as returned values.
+/// Appropriate cast ops are inserted if the distributed types do not match
+/// expected xegpu SIMT types.
+///
+/// Example:
+///
+/// ```
+///   #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   gpu.warp_execute_on_lane_0(%laneid) -> () {
+///     ...
+///     xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #lo0>
+///   }
+/// ```
+/// To
+/// ```
+///   %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> (
+///   !xegpu.tensor_desc<4x8xf32, #lo0>) {
+///     gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #lo0>
+///   }
+///   %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
+///     #lo0> -> !xegpu.tensor_desc<4x8xf32>
+///   xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32>
+///
+/// ```
 struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
@@ -1530,7 +1560,8 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
-
+    // Create a new prefetch op outside the warp op with updated tensor
+    // descriptor type. Source tensor descriptor requires type resolution.
     xegpu::TensorDescType newTensorDescTy =
         dropLayouts(prefetchOp.getTensorDescType());
     rewriter.setInsertionPointAfter(newWarpOp);



More information about the Mlir-commits mailing list