[Mlir-commits] [mlir] 4a9d038 - [MLIR][XeGPU] Distribute load_nd/store_nd/prefetch_nd with offsets from Wg to Sg (#153432)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 18 09:45:32 PDT 2025


Author: Nishant Patel
Date: 2025-08-18T09:45:29-07:00
New Revision: 4a9d038acd637c5742e6d1622d4ad803059825bd

URL: https://github.com/llvm/llvm-project/commit/4a9d038acd637c5742e6d1622d4ad803059825bd
DIFF: https://github.com/llvm/llvm-project/commit/4a9d038acd637c5742e6d1622d4ad803059825bd.diff

LOG: [MLIR][XeGPU] Distribute load_nd/store_nd/prefetch_nd with offsets from Wg to Sg (#153432)

This PR adds patterns to distribute the load/store/prefetch nd ops with
offsets from workgroup-level to subgroup-level IR. It is part of the
transition that moves offsets from create_nd to the load/store/prefetch nd ops.

Create_nd PR : #152351

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
    mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index abc291c81a76c..eb54d6887681d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -272,6 +272,11 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
 
   let builders = [
     OpBuilder<(ins "Value": $TensorDesc,
+                   "xegpu::CachePolicyAttr": $l1_hint,
+                   "xegpu::CachePolicyAttr": $l2_hint,
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $TensorDesc,
+                   "ArrayRef<OpFoldResult>": $offsets,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
                    "xegpu::CachePolicyAttr": $l3_hint)>
@@ -348,6 +353,12 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
 
   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+                    "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
+                    "xegpu::CachePolicyAttr": $l1_hint,
+                    "xegpu::CachePolicyAttr": $l2_hint,
+                    "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
+                    "ArrayRef<OpFoldResult>": $offsets,
                     "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
                     "xegpu::CachePolicyAttr": $l1_hint,
                     "xegpu::CachePolicyAttr": $l2_hint,
@@ -419,7 +430,12 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
                    "xegpu::CachePolicyAttr": $l1_hint,
                    "xegpu::CachePolicyAttr": $l2_hint,
-                   "xegpu::CachePolicyAttr": $l3_hint)>
+                   "xegpu::CachePolicyAttr": $l3_hint)>,
+    OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
+                  "ArrayRef<OpFoldResult>": $offsets,
+                  "xegpu::CachePolicyAttr": $l1_hint,
+                  "xegpu::CachePolicyAttr": $l2_hint,
+                  "xegpu::CachePolicyAttr": $l3_hint)>
   ];
 
 

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index eee0fdc7160de..906c71d8b8dad 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -385,6 +385,21 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
                l1_hint, l2_hint, l3_hint);
 }
 
+/// Builds a prefetch_nd with explicit offsets given as a mix of static
+/// integers and dynamic SSA values. With an empty `offsets` list no
+/// `const_offsets` attribute is attached, so `getConstOffsetsAttr()` stays null
+/// and the op is not mistaken for one that carries offsets.
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
+                         Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+                         xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint) {
+  // Split into SSA operands and static values; dynamic positions are encoded
+  // as ShapedType::kDynamic in `staticOffsets`.
+  SmallVector<Value> dynamicOffsets;
+  SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+  DenseI64ArrayAttr staticOffsetsAttr;
+  if (!offsets.empty())
+    staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+  build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
+        l2_hint, l3_hint);
+}
+
 LogicalResult PrefetchNdOp::verify() {
   auto tdescTy = getTensorDescType();
   if (tdescTy.isScattered())
@@ -427,6 +442,22 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
                l3_hint);
 }
 
+/// Builds a load_nd with explicit offsets given as a mix of static integers
+/// and dynamic SSA values. With an empty `offsets` list no `const_offsets`
+/// attribute is attached, so `getConstOffsetsAttr()` stays null and the op
+/// is not mistaken for one that carries offsets.
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
+                     Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+                     UnitAttr packed, DenseI64ArrayAttr transpose,
+                     xegpu::CachePolicyAttr l1_hint,
+                     xegpu::CachePolicyAttr l2_hint,
+                     xegpu::CachePolicyAttr l3_hint) {
+  // Split into SSA operands and static values; dynamic positions are encoded
+  // as ShapedType::kDynamic in `staticOffsets`.
+  SmallVector<Value> dynamicOffsets;
+  SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+  DenseI64ArrayAttr staticOffsetsAttr;
+  if (!offsets.empty())
+    staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+  build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+        packed, transpose, l1_hint, l2_hint, l3_hint);
+}
+
 LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();
@@ -533,6 +564,21 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
                DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
 }
 
+/// Builds a store_nd with explicit offsets given as a mix of static integers
+/// and dynamic SSA values. With an empty `offsets` list no `const_offsets`
+/// attribute is attached (as in the offset-less builder above, which passes
+/// a null DenseI64ArrayAttr), so the op is not mistaken for one with offsets.
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
+                      Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+                      xegpu::CachePolicyAttr l1_hint,
+                      xegpu::CachePolicyAttr l2_hint,
+                      xegpu::CachePolicyAttr l3_hint) {
+  // Split into SSA operands and static values; dynamic positions are encoded
+  // as ShapedType::kDynamic in `staticOffsets`.
+  SmallVector<Value> dynamicOffsets;
+  SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+  DenseI64ArrayAttr staticOffsetsAttr;
+  if (!offsets.empty())
+    staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+  build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+        l1_hint, l2_hint, l3_hint);
+}
+
 LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index ecec186fe3fc9..8f1208e77ca5d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -182,16 +182,16 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
                                    layout.dropSgLayoutAndData());
 
     SmallVector<Value> newCreateNdOps;
-    SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
+    SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
 
     for (auto tdescOffsets : *maybeTdescOffsets) {
       SmallVector<OpFoldResult> sgOffsets;
       size_t rank = tdescOffsets.size();
       for (size_t i = 0; i < rank; i++) {
-        size_t idx = wgOffsets.size() - rank + i;
+        size_t idx = origOffsets.size() - rank + i;
         Value add = rewriter.createOrFold<index::AddOp>(
             loc, tdescOffsets[i],
-            getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
+            getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
         sgOffsets.push_back(add);
       }
 
@@ -296,6 +296,205 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   }
 };
 
+// Adds the workgroup-level offsets of an op to every subgroup-relative offset
+// tuple in `sgOffsetsList`. Each tuple has the tensor descriptor's rank and is
+// aligned with the trailing entries of `origOffsets` (any leading entries are
+// ignored). Returns one list of absolute offsets per tuple, in the same order.
+static SmallVector<SmallVector<OpFoldResult>>
+computeOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
+               ArrayRef<OpFoldResult> origOffsets,
+               ConversionPatternRewriter &rewriter) {
+  Location loc = op->getLoc();
+  SmallVector<SmallVector<OpFoldResult>> result;
+  result.reserve(sgOffsetsList.size());
+  for (ArrayRef<Value> relOffsets : sgOffsetsList) {
+    // Only the innermost `rank` workgroup offsets line up with the tile dims.
+    ArrayRef<OpFoldResult> wgOffsets = origOffsets.take_back(relOffsets.size());
+    SmallVector<OpFoldResult> &absOffsets = result.emplace_back();
+    for (auto [rel, wg] : llvm::zip_equal(relOffsets, wgOffsets)) {
+      Value wgVal = getValueOrCreateConstantIndexOp(rewriter, loc, wg);
+      absOffsets.push_back(
+          rewriter.createOrFold<index::AddOp>(loc, rel, wgVal));
+    }
+  }
+  return result;
+}
+
+// Computes the subgroup tile shape (`sgShape`) and the current subgroup's tile
+// offsets relative to the workgroup tile (`sgOffsetList`) for an xegpu op that
+// carries explicit offsets. Fails if the op has no offsets, its tensor
+// descriptor has no LayoutAttr with sg_layout, sg_layout does not match the
+// range reported by isSgIdRangeSpecified, or the layout cannot produce
+// offsets. All structural checks run before any op is created.
+template <typename OpTy, typename AdaptorTy>
+LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor,
+                           ConversionPatternRewriter &rewriter,
+                           SmallVector<int64_t> &sgShape,
+                           SmallVector<SmallVector<Value>> &sgOffsetList) {
+  // Ops without offsets are left for the other patterns.
+  if (op.getOffsets().empty() && !op.getConstOffsetsAttr())
+    return rewriter.notifyMatchFailure(op, "op has no explicit offsets");
+
+  auto tdescTy = dyn_cast<xegpu::TensorDescType>(op.getTensorDesc().getType());
+  if (!tdescTy)
+    return rewriter.notifyMatchFailure(op, "expected a tensor descriptor");
+  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+  if (!layout)
+    return rewriter.notifyMatchFailure(op, "tensor descriptor has no layout");
+
+  auto sgLayoutAttr = layout.getSgLayout();
+  if (!sgLayoutAttr)
+    return rewriter.notifyMatchFailure(
+        op, "sgLayout attribute is required in layout");
+  SmallVector<int64_t> sgLayout =
+      llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+
+  ArrayRef<int64_t> wgShape = tdescTy.getShape();
+  // Only the per-subgroup tile shape is needed here; the count is unused.
+  std::tie(sgShape, std::ignore) = getSgShapeAndCount(wgShape, layout);
+
+  // Validate the optional sg_id_range before materializing any op, so this
+  // match failure leaves the IR untouched.
+  int64_t startOfRange = -1, endOfRange = -1;
+  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+  if (sgIdRangeSpecified &&
+      computeProduct(sgLayout) != endOfRange - startOfRange)
+    return rewriter.notifyMatchFailure(
+        op, "sg_layout size must match the sg_id_range");
+
+  // Linear subgroup id, rebased to the start of the range when one is given.
+  Location loc = op.getLoc();
+  Value linearSgId =
+      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+  if (sgIdRangeSpecified) {
+    Value startOfRangeVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+    linearSgId =
+        rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
+  }
+
+  auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+  if (failed(sgOffsets))
+    return failure();
+
+  sgOffsetList = std::move(*sgOffsets);
+  return success();
+}
+
+// Returns the op's workgroup-level offsets in positional order. `const_offsets`
+// has one entry per position (dynamic ones are ShapedType::kDynamic) and
+// `offsets` holds the SSA values for the dynamic positions, so the two lists
+// must be merged, not concatenated.
+template <typename OpTy>
+SmallVector<OpFoldResult> getOffsets(OpTy op,
+                                     ConversionPatternRewriter &rewriter) {
+  DenseI64ArrayAttr constOffsets = op.getConstOffsetsAttr();
+  if (!constOffsets)
+    return llvm::to_vector_of<OpFoldResult>(op.getOffsets());
+  return getMixedValues(constOffsets.asArrayRef(), op.getOffsets(), rewriter);
+}
+
+// This pattern transforms a LoadNdOp with explicit offsets into one LoadNdOp
+// per subgroup tile. Each new load reads through the matching distributed
+// tensor descriptor at the workgroup offset plus the tile's subgroup offset.
+struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+
+    // Do the distribution from workgroup to subgroup and get subgroup offsets
+    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+      return failure();
+
+    // Expect one distributed descriptor per subgroup tile; llvm::zip below
+    // would otherwise silently drop the surplus tiles or descriptors.
+    ValueRange tdescs = adaptor.getTensorDesc();
+    if (tdescs.size() != sgOffsetList.size())
+      return rewriter.notifyMatchFailure(
+          op, "tensor descriptor count does not match subgroup tile count");
+
+    // Get the original workgroup offsets
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+    // NOTE(review): packed/transpose are not forwarded to the subgroup loads;
+    // confirm this path is only reached for loads without them.
+    SmallVector<Value> newLoadOps;
+    for (auto [offsets, tdesc] : llvm::zip(finalOffsets, tdescs)) {
+      auto tdescTy = cast<xegpu::TensorDescType>(tdesc.getType());
+      VectorType newResTy = VectorType::get(sgShape, tdescTy.getElementType());
+      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+          op.getLoc(), newResTy, tdesc, offsets,
+          /*packed=*/nullptr,
+          /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return success();
+  }
+};
+
+// This pattern transforms a StoreNdOp with explicit offsets into one StoreNdOp
+// per subgroup tile. Each store writes the matching distributed value through
+// the matching distributed descriptor at the workgroup offset plus the tile's
+// subgroup offset.
+struct WgToSgStoreNdOpWithOffset
+    : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+
+    // Do the distribution from workgroup to subgroup and get subgroup offsets
+    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+      return failure();
+
+    // Expect one distributed value and descriptor per subgroup tile;
+    // llvm::zip below would otherwise silently skip stores.
+    ValueRange tdescs = adaptor.getTensorDesc();
+    ValueRange values = adaptor.getValue();
+    if (tdescs.size() != sgOffsetList.size() ||
+        values.size() != sgOffsetList.size())
+      return rewriter.notifyMatchFailure(
+          op, "value/tensor descriptor count does not match subgroup tile "
+              "count");
+
+    // Get the original workgroup offsets
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+    for (auto [offsets, tdesc, value] : llvm::zip(finalOffsets, tdescs, values))
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, offsets,
+                                        op.getL1HintAttr(), op.getL2HintAttr(),
+                                        op.getL3HintAttr());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// This pattern transforms a PrefetchNdOp with explicit offsets into one
+// PrefetchNdOp per subgroup tile, each addressing its tile through the
+// matching distributed descriptor at the workgroup offset plus the tile's
+// subgroup offset.
+struct WgToSgPrefetchNdOpWithOffset
+    : public OpConversionPattern<xegpu::PrefetchNdOp> {
+  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<int64_t> sgShape;
+    SmallVector<SmallVector<Value>> sgOffsetList;
+
+    // Do the distribution from workgroup to subgroup and get subgroup offsets
+    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+      return failure();
+
+    // Expect one distributed descriptor per subgroup tile; llvm::zip below
+    // would otherwise silently skip prefetches.
+    ValueRange tdescs = adaptor.getTensorDesc();
+    if (tdescs.size() != sgOffsetList.size())
+      return rewriter.notifyMatchFailure(
+          op, "tensor descriptor count does not match subgroup tile count");
+
+    // Get the original workgroup offsets
+    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+    // Calculate the final offsets for each subgroup
+    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+    for (auto [offsets, tdesc] : llvm::zip(finalOffsets, tdescs))
+      rewriter.create<xegpu::PrefetchNdOp>(
+          op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
 /// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
 /// offsets of the new subgroup src tensor descriptors.
@@ -690,12 +889,13 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
-  patterns.add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
-               WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
-               WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
-               WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
-               WgToSgConvertLayoutOp, WgToSgArithConstantOp>(
-      patterns.getContext());
+  // Load/store/prefetch each have two patterns registered. The *WithOffset
+  // variants return failure for ops that carry no explicit offsets, so such
+  // ops remain available to the other pattern registered for the same op.
+  patterns
+      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
+           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
+           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
+           WgToSgArithConstantOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir

diff  --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index b6f44b5bc0b68..6ff7a94d678a3 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -10,5 +10,76 @@ gpu.module @test_distribution {
       %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       gpu.return
-    }
+  }
+
+  // CHECK-LABEL: load_nd_tdesc_with_offset
+  gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    // Round-robin: sg_layout [8, 4] x sg_data [16, 16] spans 128x64, so each
+    // subgroup issues four 16x16 loads; types are matched on the count line.
+    // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf32>
+    // CHECK-NOT: xegpu.load_nd
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: store_nd_with_offset
+  gpu.func @store_nd_with_offset(%src: memref<256x128xf32>) {
+    // Each subgroup stores back the four 16x16 tiles it loaded.
+    // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.store_nd
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    xegpu.store_nd %load, %tdesc[0, 0]
+      : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: prefetch_nd_tdesc_with_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    // Each subgroup prefetches its four 16x16 tiles of the 256x128 tile.
+    // CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.prefetch_nd
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: dpas
+  // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
+  gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
+    // Round-robin dpas: A and B each give every subgroup four 16x16 tiles, and
+    // 16 dpas ops (4 A tiles x 4 B tiles) are expected per subgroup.
+    // NOTE(review): the "-SAME-COUNT-N" lines below are not FileCheck
+    // directives and are silently ignored. Their expectations also look stale
+    // against the inputs: B is declared with lane_layout = [1, 16] and
+    // lane_data = [2, 1] (not lane_layout = [4, 8]), and the dpas carries
+    // `layout_result_0`, not `layout`. TODO: confirm against the actual
+    // output and merge them into the preceding count lines.
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
+    // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
+    // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+    // CHECK-NOT: xegpu.dpas
+    %tdesc_a = xegpu.create_nd_tdesc %a : memref<256x128xf16>
+      -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load_a =  xegpu.load_nd %tdesc_a[0, 0]
+      : !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x256xf16>
+      -> !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+    %load_b =  xegpu.load_nd %tdesc_b[0, 0]
+      : !xegpu.tensor_desc<128x256xf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+      -> vector<128x256xf16>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout_result_0 =  #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<256x128xf16>, vector<128x256xf16> -> vector<256x256xf32>
+    gpu.return
+  }
 }

diff  --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 025d48e22307e..07a0b86223c33 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
+//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
+//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
 gpu.module @test_distribution {
   // CHECK-LABEL: create_nd_tdesc_no_offset
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -21,4 +23,244 @@ gpu.module @test_distribution {
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
       gpu.return
   }
+
+  // CHECK-LABEL: load_nd_tdesc_with_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    // sg_layout [8, 4] x sg_data [32, 32] exactly covers the 256x128 tile, so
+    // each subgroup issues a single 32x32 load. The subgroup id is split into
+    // (id floordiv 4, id mod 4) using the maps matched at the top of the file.
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: store_nd_with_offsets
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @store_nd_with_offsets(%src: memref<256x128xf32>) {
+    // One 32x32 tile per subgroup is loaded and stored back through the same
+    // descriptor at static offsets [0, 0].
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]  : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    xegpu.store_nd %load, %tdesc[0, 0]
+      : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+}
+
+  // CHECK-LABEL: prefetch_nd_tdesc_with_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    // Offsets are SSA values (%cst0) rather than static integers, exercising
+    // the dynamic-offset path of the distribution.
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: xegpu.prefetch_nd %{{.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc[%cst0, %cst0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: dpas
+  gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+    // A (sg_data [16, 128]) and B (sg_data [128, 16]) give every subgroup a
+    // single tile, so one 16x128 x 128x16 dpas is expected per subgroup.
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load_a =  xegpu.load_nd %tdesc_a[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<128x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+    %load_b =  xegpu.load_nd %tdesc_b[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+      -> vector<128x128xf16>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout_result_0 =  #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: dpas_no_sg_data
+  gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+    // No sg_data is given; the expected 16x16 tiles correspond to 128 / 8 per
+    // dim of the 8x8 sg_layout, and `order` is expected to be kept.
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+      order = [1, 0]>>
+    %load_a =  xegpu.load_nd %tdesc_a[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+      order = [1, 0]>>
+      -> vector<128x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+      order = [1, 0]>>
+    %load_b =  xegpu.load_nd %tdesc_b[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+      order = [1, 0]>>
+      -> vector<128x128xf16>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout_result_0 =  #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
+      : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: dpas_with_no_create_nd_desc
+  gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) {
+    // Operands are function arguments rather than distributed loads; the test
+    // only asserts that no 32x32 subgroup vectors appear.
+    // CHECK-NOT: vector<32x32xf32>
+    %dpas = xegpu.dpas %a, %b
+      {layout =  #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: broadcast_dim1
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32>
+  gpu.func @broadcast_dim1(%src: memref<256x1xf32>) {
+    // The 32x1 subgroup slice of the loaded 256x1 vector is expected to be
+    // broadcast along dim 1 to a 32x32 subgroup vector.
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x1xf32>
+      -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+      -> vector<256x1xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32>
+    %broadcast = vector.broadcast %load
+      {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
+      : vector<256x1xf32> to vector<256x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: broadcast_dim0
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32>
+  gpu.func @broadcast_dim0(%src: memref<1x128xf32>) {
+    // The 1x32 subgroup slice of the loaded 1x128 vector is expected to be
+    // broadcast along dim 0 to a 32x32 subgroup vector.
+    %tdesc = xegpu.create_nd_tdesc %src : memref<1x128xf32>
+      -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<1x128xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32>
+    %broadcast = vector.broadcast %load
+      {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<1x128xf32> to vector<32x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: gemm_with_load_store_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<1024x1024xf16>, %[[ARG_1:.*]]: memref<1024x1024xf16>, %[[ARG_2:.*]]: memref<1024x1024xf32>
+  gpu.func @gemm_with_load_store_offset(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {  // GEMM whose load_nd/store_nd carry the offsets instead of create_nd_tdesc
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[c1024:%.+]] = arith.constant 1024 : index
+    %c0 = arith.constant 0 : index
+    %c128 = arith.constant 128 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id  x
+    %block_id_y = gpu.block_id  y
+    %0 = arith.muli %block_id_x, %c128 : index
+    %1 = arith.muli %block_id_y, %c128 : index
+    %2 = xegpu.create_nd_tdesc %arg2 : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    // CHECK: [[DESC_A:%.+]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x128xf16>
+    // CHECK: [[DESC_B:%.+]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x16xf16>
+    %3 = xegpu.create_nd_tdesc %arg0 : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
+    %4 = xegpu.create_nd_tdesc %arg1 : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
+    // load_nd with offsets on the op itself (the create_nd_tdesc ops above take none): C, A and B tiles
+    %5 = xegpu.load_nd %2[%0, %1] : !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf32>
+    %6 = xegpu.load_nd %3[%0, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
+    %7 = xegpu.load_nd %4[%c0, %1] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
+    // scf.for carrying A, B, C; after distribution the iter_args become per-subgroup 16x128, 128x16, 16x16 tiles
+    //      CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]]
+    // CHECK-SAME: iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) ->
+    // CHECK-SAME: (vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32>)
+    //      CHECK: [[c:%.+]] = xegpu.dpas [[arg4]], [[arg5]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
+    //      CHECK: [[a:%.+]] = xegpu.load_nd [[DESC_A]][{{%.*}}, {{%.*}}]  : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
+    //      CHECK: [[b:%.+]] = xegpu.load_nd [[DESC_B]][{{%.*}}, {{%.*}}]  : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
+    //      CHECK: scf.yield [[a]], [[b]], [[c]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32>
+    %8:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %6, %arg5 = %7, %arg6 = %5)
+        -> (vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32>) {
+      // dpas on the carried tiles, then load_nd with loop-variant offsets (%arg3) for the next iteration
+      %9 = xegpu.dpas %arg4, %arg5, %arg6 {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}
+                          : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
+      %10 = xegpu.load_nd %3[%arg3, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
+      %11 = xegpu.load_nd %4[%c0, %arg3] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
+      scf.yield %10, %11, %9 : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32>
+    }
+    // store_nd with offsets: write the final C (%8#2) back at [%0, %1]
+    xegpu.store_nd %8#2, %2[%0, %1] : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @subgroup_id_range
+  gpu.func @subgroup_id_range(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) {  // two scf.if regions tagged with sg_id_range; only the range not starting at 0 should produce index.sub
+    %sg_id = gpu.subgroup_id : index
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c31 = arith.constant 31 : index
+    %c3 = arith.constant 3 : index  // NOTE(review): unused in this function, as is %src1
+    %cond1 = arith.cmpi sge, %sg_id, %c0 : index
+    %cond2 = arith.cmpi slt, %sg_id, %c1 : index  // NOTE(review): bound (< 1) differs from the sg_id_range [0, 32] attached below
+    %cond = arith.andi %cond1, %cond2 : i1
+    scf.if %cond {
+        // CHECK-NOT: index.sub
+        %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+          -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+        %load =  xegpu.load_nd %tdesc[0, 0]
+          : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+          -> vector<256x128xf32>
+    } {sg_id_range = #xegpu.range<[0, 32]>}  // range starts at 0: no subgroup-id adjustment expected
+    %cond3 = arith.cmpi sge, %sg_id, %c2 : index
+    %cond4 = arith.cmpi slt, %sg_id, %c31 : index  // NOTE(review): bound (< 31) differs from the sg_id_range [2, 18] attached below
+    %cond5 = arith.andi %cond3, %cond4 : i1
+    scf.if %cond5 {
+      // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+      // CHECK: %[[C2:.*]] = arith.constant 2 : index
+      // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
+      %tdesc = xegpu.create_nd_tdesc %src2 : memref<128x64xf32>
+        -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+      %load =  xegpu.load_nd %tdesc[0, 0]
+        : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+        -> vector<128x64xf32>
+      %exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+    }{sg_id_range = #xegpu.range<[2, 18]>}  // 16 ids if end-exclusive, matching the 4x4 sg_layout; subgroup id is offset by 2
+    gpu.return
+  }
+
+  // CHECK-LABEL: @subgroup_id_range_nested_if
+  gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {  // sg_id_range sits on the outer scf.if; ops in the nested scf.if expect subgroup id - 3
+    %sg_id = gpu.subgroup_id : index
+    %c1 = arith.constant 1 : i1  // always-true condition for the outer scf.if
+    %c3 = arith.constant 3 : index
+    %c32 = arith.constant 32 : index
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0]  // outside any sg_id_range region
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    %cond1 = arith.cmpi sge, %sg_id, %c3 : index
+    %cond2 = arith.cmpi slt, %sg_id, %c32 : index
+    %cond = arith.andi %cond1, %cond2 : i1
+    scf.if %c1 {
+      scf.if %cond {
+        // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+        // CHECK: %[[C3:.*]] = arith.constant 3 : index
+        // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
+        %td = xegpu.create_nd_tdesc %src1 : memref<128x64xf32>
+          -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+        %ld =  xegpu.load_nd %td[0, 0]
+          : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+          -> vector<128x64xf32>
+        %exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
+    }  // closes inner scf.if %cond (NOTE(review): this brace, the next one and gpu.return are under-indented)
+  } {sg_id_range = #xegpu.range<[3, 19]>}  // closes outer scf.if %c1; 16 ids if end-exclusive, matching the 4x4 sg_layout
+  gpu.return
+  }
 }


        


More information about the Mlir-commits mailing list