[Mlir-commits] [mlir] [mlir][XeGPU][XeGPUUnroll] Support new syntax with offsets moved to load_nd/store_nd/prefetch_nd (PR #160323)

Dmitry Chigarev llvmlistbot at llvm.org
Tue Sep 23 08:14:44 PDT 2025


https://github.com/dchigarev created https://github.com/llvm/llvm-project/pull/160323

Adds support for new syntax in XeGPUUnroll for:
1. `create_nd_tdesc` without offsets
2. `load_nd` with offsets
3. `store_nd` with offsets
4. `prefetch_nd` with offsets

`create_nd_tdesc` with offsets + `load_nd` with offsets is not lowered correctly: in that case the IR still contains two unrealized conversion casts that will fail later in the pipeline.
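
For reference, the unsupported combination looks roughly like this (adapted from the `load_nd_offsets_at_both_places` test added by this patch):

```mlir
%tdesc = xegpu.create_nd_tdesc %src[16, 8] : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%ld = xegpu.load_nd %tdesc[8, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
```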

The computation of the per-tile offsets is moved from the descriptors to the load/store/prefetch operations. The resulting IR contains a single descriptor that the unrolled load/store/prefetch ops index with their own offsets (see the before/after examples below).

<details><summary>old/new behavior examples</summary>

```mlir
// before unroll pass:
gpu.func @load_nd(%src: memref<256x318xf32>) -> vector<24x32xf32> {
  %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
  %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
  gpu.return %ld : vector<24x32xf32>
}

// after unroll pass (old behavior: offsets carried by create_nd_tdesc):
gpu.func @create_nd_tdesc2(%arg0: memref<256x318xf32>) -> vector<24x32xf32> {
  %cst = arith.constant dense<0.000000e+00> : vector<24x32xf32>
  %c24 = arith.constant 24 : index
  %c32 = arith.constant 32 : index
  %c8 = arith.constant 8 : index
  %c16 = arith.constant 16 : index
  // create 6 descriptors, one for each tile
  %0 = xegpu.create_nd_tdesc %arg0[%c8, %c16] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  %1 = xegpu.create_nd_tdesc %arg0[%c8, %c32] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  %2 = xegpu.create_nd_tdesc %arg0[%c16, %c16] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  %3 = xegpu.create_nd_tdesc %arg0[%c16, %c32] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  %4 = xegpu.create_nd_tdesc %arg0[%c24, %c16] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  %5 = xegpu.create_nd_tdesc %arg0[%c24, %c32] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  %6 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %7 = xegpu.load_nd %1  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %8 = xegpu.load_nd %2  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %9 = xegpu.load_nd %3  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %10 = xegpu.load_nd %4  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %11 = xegpu.load_nd %5  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  ...
}

// after unroll pass (new behavior: offsets carried by load_nd):
gpu.func @load_nd(%arg0: memref<256x318xf32>) -> vector<24x32xf32> {
  %cst = arith.constant dense<0.000000e+00> : vector<24x32xf32>
  %c24 = arith.constant 24 : index
  %c32 = arith.constant 32 : index
  %c16 = arith.constant 16 : index
  %c8 = arith.constant 8 : index
  // create a single descriptor with the tile shape
  %0 = xegpu.create_nd_tdesc %arg0 : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
  // compute tile offsets at each load, reusing the single descriptor
  %1 = xegpu.load_nd %0[%c8, %c16]  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %2 = xegpu.load_nd %0[%c8, %c32]  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %3 = xegpu.load_nd %0[%c16, %c16]  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %4 = xegpu.load_nd %0[%c16, %c32]  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %5 = xegpu.load_nd %0[%c24, %c16]  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  %6 = xegpu.load_nd %0[%c24, %c32]  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  ...
}
```

</details>
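
`store_nd` and `prefetch_nd` are unrolled in the same way. A rough sketch of the expected shape of the output (placeholder SSA names, mirroring the CHECK patterns in the new test rather than verbatim pass output):

```mlir
// single descriptor with the tile shape
%0 = xegpu.create_nd_tdesc %arg0 : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
// each unrolled store/prefetch indexes that descriptor with its own tile offsets
xegpu.store_nd %v0, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
xegpu.store_nd %v1, %0[%c0, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
xegpu.prefetch_nd %0[%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
```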





>From f7eee8847ebe967c170361eae93429f9ee339451 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 23 Sep 2025 14:37:58 +0000
Subject: [PATCH 1/2] [mlir][XeGPU][XeGPUUnroll] Support new syntax with
 offsets moved to load_nd/store_nd/prefetch_nd

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 177 +++++++++++++-----
 ...xegpu-unroll-patterns-no-desc-offsets.mlir |  61 ++++++
 2 files changed, 186 insertions(+), 52 deletions(-)
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 29c9fcdfebcdb..cad7436f23762 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -121,54 +121,81 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
   xegpu::UnrollOptions options;
 };
 
+// Generic helper function for unrolling operations with offsets.
+//
+// Iterates over tile offsets within the tensor descriptor shape and calls
+// the provided createOp function for each computed offset. This is used by
+// operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they
+// have explicit offsets that need to be adjusted for each unrolled tile.
+SmallVector<Value> computeUnrolledOffsets(
+    SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
+    ArrayRef<int64_t> targetShape,
+    const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
+    Location loc, PatternRewriter &rewriter) {
+  int64_t rank = tdescTy.getRank();
+  ArrayRef<int64_t> shape = tdescTy.getShape();
+
+  auto addi = [&](OpFoldResult a, int64_t b) -> Value {
+    std::optional<int64_t> maybeInt = getConstantIntValue(a);
+    if (maybeInt) {
+      return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
+    } else {
+      auto aV = llvm::cast<Value>(a);
+      auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
+      return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
+    }
+  };
+
+  SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
+      llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
+  auto validIdxes =
+      llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
+
+  SmallVector<Value> newOps;
+  for (SmallVector<int64_t> offsets :
+       StaticTileOffsetRange(shape, targetShape)) {
+
+    for (auto [idx, oldOff, offset] :
+         llvm::zip(validIdxes, oldOffsets, offsets))
+      mixedOffsets[idx] = addi(oldOff, offset);
+
+    auto newOp = createOp(mixedOffsets);
+    newOps.push_back(newOp);
+  }
+  return newOps;
+}
+
 struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
   using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
   LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
-    int64_t rank = tdescTy.getRank();
-    ArrayRef<int64_t> shape = tdescTy.getShape();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
 
-    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
-    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
-      std::optional<int64_t> maybeInt = getConstantIntValue(a);
-      if (maybeInt) {
-        return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
-      } else {
-        auto aV = llvm::cast<Value>(a);
-        auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
-        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
-      }
-    };
-
-    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
-
-    // For n-D memrefs where n > rank, we need to handle the last `rank`
-    // dimensions only, and keep the first `n-rank` dimensions as is.
-    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
-        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
-    auto validIdxes =
-        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
-
     SmallVector<Value> newOps;
-    for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(shape, *targetShape)) {
-
-      for (auto [idx, oldOff, offset] :
-           llvm::zip(validIdxes, oldOffsets, offsets))
-        mixedOffsets[idx] = addi(oldOff, offset);
 
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+    bool hasOffsets = op.getMixedOffsets().size() != 0;
+    if (!hasOffsets) {
       auto newOp = xegpu::CreateNdDescOp::create(
-          rewriter, loc, newTdescTy, op.getSource(), mixedOffsets,
-          op.getMixedSizes(), op.getMixedStrides());
+          rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
+          op.getMixedStrides());
       newOps.push_back(newOp);
+    } else {
+      auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
+        return xegpu::CreateNdDescOp::create(
+            rewriter, loc, newTdescTy, op.getSource(), offsets,
+            op.getMixedSizes(), op.getMixedStrides());
+      };
+
+      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
+                                      *targetShape, createOp, loc, rewriter);
     }
+
     Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
     rewriter.replaceOp(op, castOp);
 
@@ -216,17 +243,33 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
       return failure();
 
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
-      return failure();
+    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
 
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
+
+    if (hasOffsets)
+      convertedTdescTypes.resize(1);
+
     SmallVector<Value> convertedTdesc = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    for (auto t : convertedTdesc)
-      xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
-                                  op->getAttrs());
+    if (!hasOffsets) {
+      for (auto t : convertedTdesc)
+        xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
+                                    op->getAttrs());
+    } else {
+      auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
+        xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
+                                    op.getL1HintAttr(), op.getL2HintAttr(),
+                                    op.getL3HintAttr());
+        // return dummy Value to satisfy function's signature
+        return nullptr;
+      };
+
+      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
+                             createPrefetch, loc, rewriter);
+    }
 
     rewriter.eraseOp(op);
     return success();
@@ -247,26 +290,39 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
       return failure();
 
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
-      return failure();
+    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
 
     Type elemTy = tdescTy.getElementType();
     VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
 
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
+
+    if (hasOffsets)
+      convertedTdescTypes.resize(1);
+
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
-
     SmallVector<Value> newOps;
-    for (auto t : convertedTdescs) {
-      auto newOp =
-          xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t, op->getAttrs());
-      newOps.push_back(newOp);
+
+    if (!hasOffsets) {
+      for (auto t : convertedTdescs) {
+        auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
+                                             op->getAttrs());
+        newOps.push_back(newOp);
+      }
+    } else {
+      auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
+        return xegpu::LoadNdOp::create(
+            rewriter, loc, newValueTy, convertedTdescs[0], offsets,
+            op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
+            op.getL2HintAttr(), op.getL3HintAttr());
+      };
+      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
+                                      *targetShape, createLoad, loc, rewriter);
     }
 
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
-
     rewriter.replaceOp(op, castOp);
     return success();
   }
@@ -285,22 +341,39 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
       return failure();
 
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
-      return failure();
+    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
 
     SmallVector<Type> convertedValTypes =
         getUnrolledTypes(valueTy, *targetShape);
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
 
-    SmallVector<Value> convertedValues =
-        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    if (hasOffsets)
+      convertedTdescTypes.resize(1);
+
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
-      xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
-                               op.getL2HintAttr(), op.getL3HintAttr());
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    if (!hasOffsets) {
+      for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
+        xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
+                                 op.getL2HintAttr(), op.getL3HintAttr());
+    } else {
+      size_t valueIndex = 0;
+      auto createStore = [&](SmallVector<OpFoldResult> offsets) {
+        xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
+                                 convertedTdescs[0], offsets,
+                                 op.getL1HintAttr(), op.getL2HintAttr(),
+                                 op.getL3HintAttr());
+        // return dummy Value to satisfy function's signature
+        return nullptr;
+      };
+
+      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
+                             createStore, loc, rewriter);
+    }
 
     rewriter.eraseOp(op);
     return success();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
new file mode 100644
index 0000000000000..f28e82a2a4c76
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt --test-xegpu-unrolling-patterns -split-input-file %s | FileCheck %s
+
+gpu.module @xevm_test {
+
+  // CHECK-LABEL: create_nd_tdesc
+  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+  // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
+  // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>
+  // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xegpu_blocking_tile_shape__ = array<i64: 8, 16>, __xegpu_blocking_unpack__}
+  gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+  }
+
+//-----
+  // CHECK-LABEL: load_nd
+  // CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32>
+  // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}}[{{.*}}]  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
+  gpu.func @load_nd(%src: memref<256x318xf32>) -> vector<24x32xf32> {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+    gpu.return %ld : vector<24x32xf32>
+  }
+
+//-----
+  // CHECK-LABEL: load_nd_store_nd
+  // CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32>
+  //CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
+  //CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}}[{{.*}}]  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  //CHECK-COUNT-6: xegpu.store_nd {{.*}}[{{.*}}] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  gpu.func @load_nd_store_nd(%src: memref<256x318xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+    xegpu.store_nd %ld, %tdesc[0, 0] : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    gpu.return
+  }
+
+//-----
+  // CHECK-LABEL: prefetch_nd_tdesc
+  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+  // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK-COUNT-6: xegpu.prefetch_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32>
+  gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    xegpu.prefetch_nd %tdesc[8, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    gpu.return
+  }
+
+//-----
+
+  // CHECK-LABEL: load_nd_offsets_at_both_places
+  // CHECK-COUNT-2: builtin.unrealized_conversion_cast
+  gpu.func @load_nd_offsets_at_both_places(%src: memref<256x318xf32>) -> vector<24x32xf32> {
+    %tdesc = xegpu.create_nd_tdesc %src[16, 8] : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+    gpu.return %ld : vector<24x32xf32>
+  }
+}
\ No newline at end of file

>From f45f04735e0db6a662154a15b28e82682b2c6d86 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 23 Sep 2025 14:50:24 +0000
Subject: [PATCH 2/2] fix formatting

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp    | 17 +++++++++++++----
 .../xegpu-unroll-patterns-no-desc-offsets.mlir  |  2 +-
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index cad7436f23762..80d1cb12dff80 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -195,7 +195,6 @@ struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
       newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                       *targetShape, createOp, loc, rewriter);
     }
-
     Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
     rewriter.replaceOp(op, castOp);
 
@@ -248,8 +247,11 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
 
-    if (hasOffsets)
+    if (hasOffsets) {
+      // only need one tdesc, tile offsets will be computed
+      // at the operation level
       convertedTdescTypes.resize(1);
+    }
 
     SmallVector<Value> convertedTdesc = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
@@ -298,8 +300,11 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
 
-    if (hasOffsets)
+    if (hasOffsets) {
+      // only need one tdesc, tile offsets will be computed
+      // at the operation level
       convertedTdescTypes.resize(1);
+    }
 
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
@@ -323,6 +328,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     }
 
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
     rewriter.replaceOp(op, castOp);
     return success();
   }
@@ -348,8 +354,11 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
 
-    if (hasOffsets)
+    if (hasOffsets) {
+      // only need one tdesc, tile offsets will be computed
+      // at the operation level
       convertedTdescTypes.resize(1);
+    }
 
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
index f28e82a2a4c76..cbfd991b5557e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
@@ -58,4 +58,4 @@ gpu.module @xevm_test {
     %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
     gpu.return %ld : vector<24x32xf32>
   }
-}
\ No newline at end of file
+}


