[Mlir-commits] [mlir] [mlir][XeGPU][XeGPUUnroll] Support new syntax with offsets moved to load_nd/store_nd/prefetch_nd (PR #160323)

Dmitry Chigarev llvmlistbot at llvm.org
Wed Sep 24 03:09:50 PDT 2025


https://github.com/dchigarev updated https://github.com/llvm/llvm-project/pull/160323

>From f7eee8847ebe967c170361eae93429f9ee339451 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 23 Sep 2025 14:37:58 +0000
Subject: [PATCH 1/3] [mlir][XeGPU][XeGPUUnroll] Support new syntax with
 offsets moved to load_nd/store_nd/prefetch_nd

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 177 +++++++++++++-----
 ...xegpu-unroll-patterns-no-desc-offsets.mlir |  61 ++++++
 2 files changed, 186 insertions(+), 52 deletions(-)
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 29c9fcdfebcdb..cad7436f23762 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -121,54 +121,81 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
   xegpu::UnrollOptions options;
 };
 
+// Generic helper function for unrolling operations with offsets.
+//
+// Iterates over tile offsets within the tensor descriptor shape and calls
+// the provided createOp function for each computed offset. This is used by
+// operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they
+// have explicit offsets that need to be adjusted for each unrolled tile.
+SmallVector<Value> computeUnrolledOffsets(
+    SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
+    ArrayRef<int64_t> targetShape,
+    const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
+    Location loc, PatternRewriter &rewriter) {
+  int64_t rank = tdescTy.getRank();
+  ArrayRef<int64_t> shape = tdescTy.getShape();
+
+  auto addi = [&](OpFoldResult a, int64_t b) -> Value {
+    std::optional<int64_t> maybeInt = getConstantIntValue(a);
+    if (maybeInt) {
+      return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
+    } else {
+      auto aV = llvm::cast<Value>(a);
+      auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
+      return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
+    }
+  };
+
+  SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
+      llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
+  auto validIdxes =
+      llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
+
+  SmallVector<Value> newOps;
+  for (SmallVector<int64_t> offsets :
+       StaticTileOffsetRange(shape, targetShape)) {
+
+    for (auto [idx, oldOff, offset] :
+         llvm::zip(validIdxes, oldOffsets, offsets))
+      mixedOffsets[idx] = addi(oldOff, offset);
+
+    auto newOp = createOp(mixedOffsets);
+    newOps.push_back(newOp);
+  }
+  return newOps;
+}
+
 struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
   using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
   LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
-    int64_t rank = tdescTy.getRank();
-    ArrayRef<int64_t> shape = tdescTy.getShape();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
 
-    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
-    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
-      std::optional<int64_t> maybeInt = getConstantIntValue(a);
-      if (maybeInt) {
-        return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
-      } else {
-        auto aV = llvm::cast<Value>(a);
-        auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
-        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
-      }
-    };
-
-    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
-
-    // For n-D memrefs where n > rank, we need to handle the last `rank`
-    // dimensions only, and keep the first `n-rank` dimensions as is.
-    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
-        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
-    auto validIdxes =
-        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
-
     SmallVector<Value> newOps;
-    for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(shape, *targetShape)) {
-
-      for (auto [idx, oldOff, offset] :
-           llvm::zip(validIdxes, oldOffsets, offsets))
-        mixedOffsets[idx] = addi(oldOff, offset);
 
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+    bool hasOffsets = op.getMixedOffsets().size() != 0;
+    if (!hasOffsets) {
       auto newOp = xegpu::CreateNdDescOp::create(
-          rewriter, loc, newTdescTy, op.getSource(), mixedOffsets,
-          op.getMixedSizes(), op.getMixedStrides());
+          rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
+          op.getMixedStrides());
       newOps.push_back(newOp);
+    } else {
+      auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
+        return xegpu::CreateNdDescOp::create(
+            rewriter, loc, newTdescTy, op.getSource(), offsets,
+            op.getMixedSizes(), op.getMixedStrides());
+      };
+
+      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
+                                      *targetShape, createOp, loc, rewriter);
     }
+
     Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
     rewriter.replaceOp(op, castOp);
 
@@ -216,17 +243,33 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
       return failure();
 
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
-      return failure();
+    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
 
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
+
+    if (hasOffsets)
+      convertedTdescTypes.resize(1);
+
     SmallVector<Value> convertedTdesc = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    for (auto t : convertedTdesc)
-      xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
-                                  op->getAttrs());
+    if (!hasOffsets) {
+      for (auto t : convertedTdesc)
+        xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
+                                    op->getAttrs());
+    } else {
+      auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
+        xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
+                                    op.getL1HintAttr(), op.getL2HintAttr(),
+                                    op.getL3HintAttr());
+        // Return a dummy Value to satisfy the function's signature.
+        return nullptr;
+      };
+
+      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
+                             createPrefetch, loc, rewriter);
+    }
 
     rewriter.eraseOp(op);
     return success();
@@ -247,26 +290,39 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
       return failure();
 
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
-      return failure();
+    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
 
     Type elemTy = tdescTy.getElementType();
     VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
 
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
+
+    if (hasOffsets)
+      convertedTdescTypes.resize(1);
+
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
-
     SmallVector<Value> newOps;
-    for (auto t : convertedTdescs) {
-      auto newOp =
-          xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t, op->getAttrs());
-      newOps.push_back(newOp);
+
+    if (!hasOffsets) {
+      for (auto t : convertedTdescs) {
+        auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
+                                             op->getAttrs());
+        newOps.push_back(newOp);
+      }
+    } else {
+      auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
+        return xegpu::LoadNdOp::create(
+            rewriter, loc, newValueTy, convertedTdescs[0], offsets,
+            op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
+            op.getL2HintAttr(), op.getL3HintAttr());
+      };
+      newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
+                                      *targetShape, createLoad, loc, rewriter);
     }
 
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
-
     rewriter.replaceOp(op, castOp);
     return success();
   }
@@ -285,22 +341,39 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
       return failure();
 
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
-      return failure();
+    bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
 
     SmallVector<Type> convertedValTypes =
         getUnrolledTypes(valueTy, *targetShape);
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
 
-    SmallVector<Value> convertedValues =
-        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    if (hasOffsets)
+      convertedTdescTypes.resize(1);
+
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
-      xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
-                               op.getL2HintAttr(), op.getL3HintAttr());
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    if (!hasOffsets) {
+      for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
+        xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
+                                 op.getL2HintAttr(), op.getL3HintAttr());
+    } else {
+      size_t valueIndex = 0;
+      auto createStore = [&](SmallVector<OpFoldResult> offsets) {
+        xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
+                                 convertedTdescs[0], offsets,
+                                 op.getL1HintAttr(), op.getL2HintAttr(),
+                                 op.getL3HintAttr());
+        // Return a dummy Value to satisfy the function's signature.
+        return nullptr;
+      };
+
+      computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
+                             createStore, loc, rewriter);
+    }
 
     rewriter.eraseOp(op);
     return success();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
new file mode 100644
index 0000000000000..f28e82a2a4c76
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt --test-xegpu-unrolling-patterns -split-input-file %s | FileCheck %s
+
+gpu.module @xevm_test {
+
+  // CHECK-LABEL: create_nd_tdesc
+  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+  // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
+  // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>
+  // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xegpu_blocking_tile_shape__ = array<i64: 8, 16>, __xegpu_blocking_unpack__}
+  gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+  }
+
+//-----
+  // CHECK-LABEL: load_nd
+  // CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32>
+  // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}}[{{.*}}]  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
+  gpu.func @load_nd(%src: memref<256x318xf32>) -> vector<24x32xf32> {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+    gpu.return %ld : vector<24x32xf32>
+  }
+
+//-----
+  // CHECK-LABEL: load_nd_store_nd
+  // CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32>
+  //CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
+  //CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}}[{{.*}}]  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  //CHECK-COUNT-6: xegpu.store_nd {{.*}}[{{.*}}] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  gpu.func @load_nd_store_nd(%src: memref<256x318xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+    xegpu.store_nd %ld, %tdesc[0, 0] : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    gpu.return
+  }
+
+//-----
+  // CHECK-LABEL: prefetch_nd_tdesc
+  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+  // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK-COUNT-6: xegpu.prefetch_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32>
+  gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    xegpu.prefetch_nd %tdesc[8, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    gpu.return
+  }
+
+//-----
+
+  // CHECK-LABEL: load_nd_offsets_at_both_places
+  // CHECK-COUNT-2: builtin.unrealized_conversion_cast
+  gpu.func @load_nd_offsets_at_both_places(%src: memref<256x318xf32>) -> vector<24x32xf32> {
+    %tdesc = xegpu.create_nd_tdesc %src[16, 8] : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+    gpu.return %ld : vector<24x32xf32>
+  }
+}
\ No newline at end of file
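
For illustration only (not part of the patch): a minimal standalone sketch of
the offset enumeration that computeUnrolledOffsets performs for the load_nd
test above. The shapes are taken from the test (tensor_desc<24x32xf32>,
inst_data = [8, 16], base offsets [8, 16]); the sketch assumes static, evenly
divisible shapes, as StaticTileOffsetRange does.

    #include <cstdio>

    int main() {
      const int shape[2] = {24, 32}; // descriptor shape
      const int tile[2] = {8, 16};   // target (inst_data) tile shape
      const int base[2] = {8, 16};   // offsets on the load_nd op itself
      // StaticTileOffsetRange walks tile origins in row-major order;
      // computeUnrolledOffsets adds each origin to the op's base offsets.
      for (int i = 0; i < shape[0]; i += tile[0])
        for (int j = 0; j < shape[1]; j += tile[1])
          std::printf("load_nd at [%d, %d]\n", base[0] + i, base[1] + j);
      // Prints six offsets: [8,16] [8,32] [16,16] [16,32] [24,16] [24,32],
      // matching the CHECK-COUNT-6 lines in the new test.
      return 0;
    }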

>From f45f04735e0db6a662154a15b28e82682b2c6d86 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 23 Sep 2025 14:50:24 +0000
Subject: [PATCH 2/3] fix formatting

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp    | 17 +++++++++++++----
 .../xegpu-unroll-patterns-no-desc-offsets.mlir  |  2 +-
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index cad7436f23762..80d1cb12dff80 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -195,7 +195,6 @@ struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
       newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
                                       *targetShape, createOp, loc, rewriter);
     }
-
     Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
     rewriter.replaceOp(op, castOp);
 
@@ -248,8 +247,11 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
 
-    if (hasOffsets)
+    if (hasOffsets) {
+      // Only one tdesc is needed; tile offsets will be computed
+      // at the operation level.
       convertedTdescTypes.resize(1);
+    }
 
     SmallVector<Value> convertedTdesc = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
@@ -298,8 +300,11 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
 
-    if (hasOffsets)
+    if (hasOffsets) {
+      // Only one tdesc is needed; tile offsets will be computed
+      // at the operation level.
       convertedTdescTypes.resize(1);
+    }
 
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
@@ -323,6 +328,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     }
 
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
     rewriter.replaceOp(op, castOp);
     return success();
   }
@@ -348,8 +354,11 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
 
-    if (hasOffsets)
+    if (hasOffsets) {
+      // Only one tdesc is needed; tile offsets will be computed
+      // at the operation level.
       convertedTdescTypes.resize(1);
+    }
 
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
index f28e82a2a4c76..cbfd991b5557e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
@@ -58,4 +58,4 @@ gpu.module @xevm_test {
     %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
     gpu.return %ld : vector<24x32xf32>
   }
-}
\ No newline at end of file
+}

>From 932346e77594318552ae60d64473dda041192888 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Wed, 24 Sep 2025 10:09:36 +0000
Subject: [PATCH 3/3] Modify 'unrolledTypeFn' to return a single type

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../Dialect/XeGPU/Transforms/Transforms.h     |  8 +++--
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |  5 ++-
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 35 +++++--------------
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  5 ++-
 4 files changed, 22 insertions(+), 31 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 44b81796b1313..b74c15e5b7ac1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -9,9 +9,9 @@
 #ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/IR/Operation.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/LogicalResult.h"
-#include "mlir/IR/Operation.h"
 
 #include <functional>
 #include <optional>
@@ -47,9 +47,11 @@ struct UnrollOptions {
 
   /// Function that converts a ShapedType (TensorDescType or VectorType)
   /// into the unrolled type based on the tileShape. It returns a vector of
-  /// types representing the unrolled types for simplicity.
+  /// types representing the unrolled types for simplicity. When
+  /// `returnSingleType` is true, it returns a vector containing a single
+  /// unrolled type.
   using UnrolledTypeFnType = std::function<SmallVector<Type>(
-      ShapedType type, ArrayRef<int64_t> tileShape)>;
+      ShapedType type, ArrayRef<int64_t> tileShape, bool returnSingleType)>;
   UnrolledTypeFnType getUnrolledTypes = nullptr;
   UnrollOptions &setUnrolledTypesFn(UnrolledTypeFnType fn) {
     getUnrolledTypes = std::move(fn);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 7efa4b9fbd934..36c498e8b849d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -319,7 +319,8 @@ void XeGPUBlockingPass::runOnOperation() {
 
   options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
 
-  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
+  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape,
+                                 bool returnSingleType = false) {
     Type elemTy = type.getElementType();
     Type newTy;
 
@@ -352,6 +353,8 @@ void XeGPUBlockingPass::runOnOperation() {
       newTy = type.clone(tileShape, elemTy);
     }
 
+    if (returnSingleType)
+      return SmallVector<Type>{newTy};
     std::optional<SmallVector<int64_t>> ratio =
         computeShapeRatio(type.getShape(), tileShape);
     assert(ratio && "The shape of the type must be a multiple of tileShape.");
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 80d1cb12dff80..f738effe46a72 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -56,8 +56,9 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
   }
 
   SmallVector<Type> getUnrolledTypes(ShapedType type,
-                                     ArrayRef<int64_t> tileShape) const {
-    return options.getUnrolledTypes(type, tileShape);
+                                     ArrayRef<int64_t> tileShape,
+                                     bool returnSingleType = false) const {
+    return options.getUnrolledTypes(type, tileShape, returnSingleType);
   }
 
   /// Emulate the unpack behavior using insert_strided_slice for VectorType
@@ -244,14 +245,8 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
     int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
     bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
 
-    SmallVector<Type> convertedTdescTypes =
-        getUnrolledTypes(tdescTy, *targetShape);
-
-    if (hasOffsets) {
-      // Only one tdesc is needed; tile offsets will be computed
-      // at the operation level.
-      convertedTdescTypes.resize(1);
-    }
+    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+        tdescTy, *targetShape, /*returnSingleType=*/hasOffsets);
 
     SmallVector<Value> convertedTdesc = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
@@ -297,14 +292,8 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     Type elemTy = tdescTy.getElementType();
     VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
 
-    SmallVector<Type> convertedTdescTypes =
-        getUnrolledTypes(tdescTy, *targetShape);
-
-    if (hasOffsets) {
-      // Only one tdesc is needed; tile offsets will be computed
-      // at the operation level.
-      convertedTdescTypes.resize(1);
-    }
+    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+        tdescTy, *targetShape, /*returnSingleType=*/hasOffsets);
 
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
@@ -351,14 +340,8 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
 
     SmallVector<Type> convertedValTypes =
         getUnrolledTypes(valueTy, *targetShape);
-    SmallVector<Type> convertedTdescTypes =
-        getUnrolledTypes(tdescTy, *targetShape);
-
-    if (hasOffsets) {
-      // Only one tdesc is needed; tile offsets will be computed
-      // at the operation level.
-      convertedTdescTypes.resize(1);
-    }
+    SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+        tdescTy, *targetShape, /*returnSingleType=*/hasOffsets);
 
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index e1ba45c60ac36..b2bdf3efc65f7 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -95,7 +95,8 @@ struct TestXeGPUUnrollingPatterns
         });
 
     options.setUnrolledTypesFn(
-        [&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
+        [&](ShapedType type, ArrayRef<int64_t> tileShape,
+            bool returnSingleType = false) -> SmallVector<Type> {
           Type elemTy = type.getElementType();
           Type newTy;
 
@@ -137,6 +138,8 @@ struct TestXeGPUUnrollingPatterns
             newTy = type.clone(tileShape, elemTy);
           }
 
+          if (returnSingleType)
+            return SmallVector<Type>{newTy};
           std::optional<SmallVector<int64_t>> ratio =
               computeShapeRatio(type.getShape(), tileShape);
           assert(ratio && "Expecting the ratio to be valid.");
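
As a rough sketch of the new contract (toy types, not the actual MLIR API):
without `returnSingleType` the callback returns ratio-many copies of the tile
type, one per unrolled instance; with the flag it returns exactly one, since a
single tdesc is reused and per-tile offsets are materialized on each
load/store/prefetch instead.

    #include <cassert>
    #include <cstdint>
    #include <string>
    #include <vector>

    // Toy stand-in for UnrolledTypeFnType: std::string plays the role of
    // mlir::Type, and `ratio` stands for the precomputed shape ratio.
    std::vector<std::string> getUnrolledTypes(const std::string &tileTy,
                                              int64_t ratio,
                                              bool returnSingleType = false) {
      if (returnSingleType)
        return {tileTy}; // one tdesc is reused; offsets are applied per op
      return std::vector<std::string>(ratio, tileTy); // one tdesc per tile
    }

    int main() {
      // 24x32 unrolled by 8x16 gives a ratio of 6 tiles.
      assert(getUnrolledTypes("tensor_desc<8x16xf32>", 6).size() == 6);
      assert(getUnrolledTypes("tensor_desc<8x16xf32>", 6,
                              /*returnSingleType=*/true)
                 .size() == 1);
      return 0;
    }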


