[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance Pack/Unpack for XeGPUUnroll (PR #163459)

Nishant Patel llvmlistbot at llvm.org
Fri Oct 24 10:26:04 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/163459

>From 23f9557209d0661040c5ce54f2af3a72cf17b712 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 14 Oct 2025 21:39:46 +0000
Subject: [PATCH 1/6] Enhance Pack/Unpack for XeGPUUnroll

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        | 11 +++--
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  6 +--
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 27 ++++++++++--
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir   | 43 +++++++++++++++++++
 4 files changed, 75 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index f77784abaf0b2..48831728ad624 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -145,8 +145,13 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
   xegpu::DistributeLayoutAttr layout =
       xegpu::getDistributeLayoutAttr(operandOrResult);
   if (layout && layout.isForSubgroup()) {
-    if (!layout.getEffectiveInstDataAsInt().empty())
-      return layout.getEffectiveInstDataAsInt();
+    if (!layout.getEffectiveInstDataAsInt().empty()) {
+      SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
+      // Remove all leading unit dimensions from inst_data
+      while (!instData.empty() && instData.front() == 1)
+        instData.erase(instData.begin());
+      return instData;
+    }
 
     if (auto type = dyn_cast<ShapedType>(value.getType()))
       return llvm::to_vector(type.getShape());
@@ -363,7 +368,7 @@ void XeGPUBlockingPass::runOnOperation() {
           xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
                                      tdescTy.getLayoutAttr().dropInstData());
     } else {
-      newTy = type.clone(tileShape, elemTy);
+      newTy = VectorType::get(tileShape, elemTy);
     }
 
     if (returnSingleType)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a178d0fe4b0b0..75b215c320e54 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -66,8 +66,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
   Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
                Location loc, PatternRewriter &rewriter) const {
     if (auto vecTy = dyn_cast<VectorType>(destTy)) {
-      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
-             "Expecting blockSize size to match the rank of destTy.");
       auto shape = vecTy.getShape();
       return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
     }
@@ -93,8 +91,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
                           ArrayRef<int64_t> blockSize, Location loc,
                           PatternRewriter &rewriter) const {
     if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
-      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
-             "Expecting blockSize size to match the rank of src.");
       return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                      blockSize);
     }
@@ -635,7 +631,7 @@ struct UnrollLoadGatherOpWithOffset
     VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
     VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
     Type elemTy = valueTy.getElementType();
-    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+    VectorType newValueTy = VectorType::get(*targetShape, elemTy);
 
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2c56a438ea62c..40013eb161678 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -246,11 +246,30 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
   if (!computeShapeRatio(srcShape, shape))
     return {value};
 
+  int64_t srcShapeRank = srcShape.size();
+  int64_t targetShapeRank = shape.size();
+
+  SmallVector<int64_t> adjustedTargetShape(srcShape.size());
+  int64_t rankDiff = srcShapeRank - targetShapeRank;
+  std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
+            1);
+  std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
+
+  int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
+
   SmallVector<Value> result;
-  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
+  for (SmallVector<int64_t> offsets :
+       StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
     SmallVector<int64_t> staticStrides(offsets.size(), 1);
-    result.push_back(vector::ExtractStridedSliceOp::create(
-        builder, loc, value, offsets, shape, staticStrides));
+    Value slice = vector::ExtractStridedSliceOp::create(
+        builder, loc, value, offsets, adjustedTargetShape, staticStrides);
+
+    // Reshape to remove leading unit dims if needed
+    if (adjustedTargetShapeRank > targetShapeRank) {
+      auto targetTy = VectorType::get(shape, vecTy.getElementType());
+      slice = builder.create<vector::ShapeCastOp>(loc, targetTy, slice);
+    }
+    result.push_back(slice);
   }
 
   return result;
@@ -274,7 +293,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
 
   for (auto [src, offsets] :
        llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
-    SmallVector<int64_t> staticStrides(offsets.size(), 1);
+    SmallVector<int64_t> staticStrides(tileShape.size(), 1);
     result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
                                                   offsets, staticStrides);
   }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index fe4f44c0b02ab..6301533da640d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -682,3 +682,46 @@ gpu.module @test_kernel {
     gpu.return
   }
 }
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: load_gather
+  // CHECK-COUNT-2: xegpu.load  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  gpu.func @load_gather(%src: ui64) -> vector<1x1x32xf32> {
+      %cst = arith.constant dense<[[
+      [0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248]
+      ]]> : vector<1x1x32xindex>
+
+      %mask = arith.constant dense<true> : vector<1x1x32xi1>
+      %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+
+      gpu.return %ld : vector<1x1x32xf32>
+  }
+}
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: store_scatter
+  // CHECK-COUNT-2: xegpu.store  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
+  gpu.func @store_scatter(%src: ui64) {
+      %cst = arith.constant dense<[[
+      [0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248]
+      ]]> : vector<1x1x32xindex>
+
+      %mask = arith.constant dense<true> : vector<1x1x32xi1>
+
+      %st_vec = arith.constant dense<1023.0>: vector<1x1x32xf32>
+      xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<inst_data = [1, 1, 16]>,
+                                              layout_operand_2 = #xegpu.layout<inst_data = [1, 1, 16]>,
+                                              layout_operand_3 = #xegpu.layout<inst_data = [1, 1, 16]>,
+                                              l1_hint = #xegpu.cache_hint<cached>} : vector<1x1x32xf32>, ui64, vector<1x1x32xindex>, vector<1x1x32xi1>
+
+      gpu.return
+  }
+}

>From 19278f9545074308d8fe1baf8210eecce69dee83 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 15 Oct 2025 18:24:27 +0000
Subject: [PATCH 2/6] Address comments: use exact CHECK lines and add layouts to test constants

---
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir | 24 ++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 6301533da640d..57af76aead1d3 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -686,7 +686,15 @@ gpu.module @test_kernel {
 // -----
 gpu.module @test_kernel {
   // CHECK-LABEL: load_gather
-  // CHECK-COUNT-2: xegpu.load  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<1x1x32xf32>
+  // CHECK: [[cst_0:%.+]] = arith.constant dense<true> : vector<16xi1>
+  // CHECK: [[cst_1:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+  // CHECK: [[cst_2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+  // CHECK: [[ld_0:%.+]] = xegpu.load [[arg0]][[[cst_1]]], [[cst_0]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  // CHECK: [[ld_1:%.+]] = xegpu.load [[arg0]][[[cst_2]]], [[cst_0]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  // CHECK: [[ins_0:%.+]] = vector.insert_strided_slice [[ld_0]], [[cst]] {offsets = [0, 0, 0], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
+  // CHECK: [[ins_1:%.+]] = vector.insert_strided_slice [[ld_1]], [[ins_0]] {offsets = [0, 0, 16], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
   gpu.func @load_gather(%src: ui64) -> vector<1x1x32xf32> {
       %cst = arith.constant dense<[[
       [0,   8,  16,  24,  32,  40,  48,  56,
@@ -705,18 +713,24 @@ gpu.module @test_kernel {
 // -----
 gpu.module @test_kernel {
   // CHECK-LABEL: store_scatter
-  // CHECK-COUNT-2: xegpu.store  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-DAG: [[cst:%.+]] = arith.constant dense<true> : vector<16xi1>
+  // CHECK-DAG: [[cst_0:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+  // CHECK-DAG: [[cst_1:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+  // CHECK-DAG: [[cst_2:%.+]] = arith.constant dense<1.023000e+03> : vector<16xf32>
+  // CHECK: xegpu.store [[cst_2]], [[arg0]][[[cst_0]]], [[cst]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
+  // CHECK: xegpu.store [[cst_2]], [[arg0]][[[cst_1]]], [[cst]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
   gpu.func @store_scatter(%src: ui64) {
-      %cst = arith.constant dense<[[
+      %cst = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<[[
       [0,   8,  16,  24,  32,  40,  48,  56,
       64,  72,  80,  88,  96, 104, 112, 120,
       128, 136, 144, 152, 160, 168, 176, 184,
       192, 200, 208, 216, 224, 232, 240, 248]
       ]]> : vector<1x1x32xindex>
 
-      %mask = arith.constant dense<true> : vector<1x1x32xi1>
+      %mask = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<true>  : vector<1x1x32xi1>
 
-      %st_vec = arith.constant dense<1023.0>: vector<1x1x32xf32>
+      %st_vec = arith.constant dense<1023.0> : vector<1x1x32xf32>
       xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<inst_data = [1, 1, 16]>,
                                               layout_operand_2 = #xegpu.layout<inst_data = [1, 1, 16]>,
                                               layout_operand_3 = #xegpu.layout<inst_data = [1, 1, 16]>,

>From 7019948fd3f3581a2623e7b9c06276ab095817de Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 16 Oct 2025 03:07:42 +0000
Subject: [PATCH 3/6] Skip trimming leading unit dims for xegpu nd ops

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        | 15 +++--
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir   | 59 +++++++++++--------
 2 files changed, 47 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 48831728ad624..10b69e0a29aa5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -147,9 +147,17 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
   if (layout && layout.isForSubgroup()) {
     if (!layout.getEffectiveInstDataAsInt().empty()) {
       SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
-      // Remove all leading unit dimensions from inst_data
-      while (!instData.empty() && instData.front() == 1)
-        instData.erase(instData.begin());
+      // Remove leading unit dimensions from inst_data
+      // Skip it for xegpu nd ops since it will be 2D
+      Operation *definingOp = value.getDefiningOp();
+      bool skipLeadingUnitDimRemoval =
+          definingOp &&
+          (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp,
+               xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp));
+      if (!skipLeadingUnitDimRemoval) {
+        while (!instData.empty() && instData.front() == 1)
+          instData.erase(instData.begin());
+      }
       return instData;
     }
 
@@ -359,7 +367,6 @@ void XeGPUBlockingPass::runOnOperation() {
           // To create a new attribute with a different chunk_size:
           auto newEncoding = xegpu::ScatterTensorDescAttr::get(
               ctx, tdescTy.getMemorySpace(), blockedChunkSize);
-
           encoding = newEncoding;
         }
       }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 57af76aead1d3..f8eccf54d4fa1 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -711,31 +711,44 @@ gpu.module @test_kernel {
 }
 
 // -----
+#l = #xegpu.layout<inst_data = [1, 16]>
 gpu.module @test_kernel {
-  // CHECK-LABEL: store_scatter
-  // CHECK-SAME: [[arg0:%.+]]: ui64
-  // CHECK-DAG: [[cst:%.+]] = arith.constant dense<true> : vector<16xi1>
-  // CHECK-DAG: [[cst_0:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
-  // CHECK-DAG: [[cst_1:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
-  // CHECK-DAG: [[cst_2:%.+]] = arith.constant dense<1.023000e+03> : vector<16xf32>
-  // CHECK: xegpu.store [[cst_2]], [[arg0]][[[cst_0]]], [[cst]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
-  // CHECK: xegpu.store [[cst_2]], [[arg0]][[[cst_1]]], [[cst]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
-  gpu.func @store_scatter(%src: ui64) {
-      %cst = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<[[
-      [0,   8,  16,  24,  32,  40,  48,  56,
-      64,  72,  80,  88,  96, 104, 112, 120,
-      128, 136, 144, 152, 160, 168, 176, 184,
-      192, 200, 208, 216, 224, 232, 240, 248]
-      ]]> : vector<1x1x32xindex>
+  // CHECK-LABEL: load_store_nd_with_offsets
+  // CHECK-SAME: [[arg0:%.+]]: memref<1024x1024xf32>, [[arg1:%.+]]: memref<1024x1024xf32>, [[arg2:%.+]]: memref<1024x1024xf32>
+  // CHECK-DAG: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+  // CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
+  // CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
+  // CHECK: [[tdesc_a:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
+  // CHECK: [[tdesc_b:%.+]] = xegpu.create_nd_tdesc [[arg1]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
+  // CHECK: [[tdesc_c:%.+]] = xegpu.create_nd_tdesc [[arg2]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
+  // CHECK: [[ld_a0:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c0]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+  // CHECK: [[ld_a1:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c16]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+  // CHECK: [[ld_b0:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c0]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+  // CHECK: [[ld_b1:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c16]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+  // CHECK: [[cast_a0:%.+]] = vector.shape_cast [[ld_a0]] : vector<1x16xf32> to vector<16xf32>
+  // CHECK: [[cast_b0:%.+]] = vector.shape_cast [[ld_b0]] : vector<1x16xf32> to vector<16xf32>
+  // CHECK: [[add0:%.+]] = arith.addf [[cast_a0]], [[cast_b0]] : vector<16xf32>
+  // CHECK: [[ins0:%.+]] = vector.insert_strided_slice [[add0]], [[cst]] {offsets = [0, 0], strides = [1]} : vector<16xf32> into vector<1x32xf32>
+  // CHECK: [[cast_a1:%.+]] = vector.shape_cast [[ld_a1]] : vector<1x16xf32> to vector<16xf32>
+  // CHECK: [[cast_b1:%.+]] = vector.shape_cast [[ld_b1]] : vector<1x16xf32> to vector<16xf32>
+  // CHECK: [[add1:%.+]] = arith.addf [[cast_a1]], [[cast_b1]] : vector<16xf32>
+  // CHECK: [[ins1:%.+]] = vector.insert_strided_slice [[add1]], [[ins0]] {offsets = [0, 16], strides = [1]} : vector<16xf32> into vector<1x32xf32>
+  // CHECK: [[ext0:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
+  // CHECK: [[ext1:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 16], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
+  // CHECK: xegpu.store_nd [[ext0]], [[tdesc_c]][[[c0]], [[c0]]]  : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
+  // CHECK: xegpu.store_nd [[ext1]], [[tdesc_c]][[[c0]], [[c16]]]  : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
+  gpu.func @load_store_nd_with_offsets(%A: memref<1024x1024xf32>, %B: memref<1024x1024xf32>, %C: memref<1024x1024xf32>) {
+    %c0 = arith.constant 0 : index
 
-      %mask = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<true>  : vector<1x1x32xi1>
+    %a_tdesc = xegpu.create_nd_tdesc %A : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
+    %b_tdesc = xegpu.create_nd_tdesc %B : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
+    %c_tdesc = xegpu.create_nd_tdesc %C : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
 
-      %st_vec = arith.constant dense<1023.0> : vector<1x1x32xf32>
-      xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<inst_data = [1, 1, 16]>,
-                                              layout_operand_2 = #xegpu.layout<inst_data = [1, 1, 16]>,
-                                              layout_operand_3 = #xegpu.layout<inst_data = [1, 1, 16]>,
-                                              l1_hint = #xegpu.cache_hint<cached>} : vector<1x1x32xf32>, ui64, vector<1x1x32xindex>, vector<1x1x32xi1>
+    %a = xegpu.load_nd %a_tdesc[%c0, %c0] : !xegpu.tensor_desc<1x32xf32, #l> -> vector<1x32xf32>
+    %b = xegpu.load_nd %b_tdesc[%c0, %c0] : !xegpu.tensor_desc<1x32xf32, #l> -> vector<1x32xf32>
 
-      gpu.return
+    %result = arith.addf %a, %b {layout_result_0 = #l} : vector<1x32xf32>
+    xegpu.store_nd %result, %c_tdesc[%c0, %c0] : vector<1x32xf32>, !xegpu.tensor_desc<1x32xf32, #l>
+    gpu.return
   }
-}
+}
\ No newline at end of file

>From a4ebc377678e9d926710def4a8c48fecf5bce2c4 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 16 Oct 2025 16:03:05 +0000
Subject: [PATCH 4/6] Add layouts to load_gather test constants and trailing newline

---
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index f8eccf54d4fa1..4e533dc55333d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -696,14 +696,14 @@ gpu.module @test_kernel {
   // CHECK: [[ins_0:%.+]] = vector.insert_strided_slice [[ld_0]], [[cst]] {offsets = [0, 0, 0], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
   // CHECK: [[ins_1:%.+]] = vector.insert_strided_slice [[ld_1]], [[ins_0]] {offsets = [0, 0, 16], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
   gpu.func @load_gather(%src: ui64) -> vector<1x1x32xf32> {
-      %cst = arith.constant dense<[[
+      %cst = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<[[
       [0,   8,  16,  24,  32,  40,  48,  56,
       64,  72,  80,  88,  96, 104, 112, 120,
       128, 136, 144, 152, 160, 168, 176, 184,
       192, 200, 208, 216, 224, 232, 240, 248]
       ]]> : vector<1x1x32xindex>
 
-      %mask = arith.constant dense<true> : vector<1x1x32xi1>
+      %mask = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<true> : vector<1x1x32xi1>
       %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
 
       gpu.return %ld : vector<1x1x32xf32>
@@ -751,4 +751,4 @@ gpu.module @test_kernel {
     xegpu.store_nd %result, %c_tdesc[%c0, %c0] : vector<1x32xf32>, !xegpu.tensor_desc<1x32xf32, #l>
     gpu.return
   }
-}
\ No newline at end of file
+}

>From e6a814e8d4bc9dd61dfcfb4e6fe8696f1bfc2956 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 22 Oct 2025 18:01:08 +0000
Subject: [PATCH 5/6] Address comments: use llvm::find_if, drop unused variable, rename test

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 6 ++++--
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp         | 4 +---
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir         | 4 ++--
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 10b69e0a29aa5..b519746285868 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -148,6 +148,8 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
     if (!layout.getEffectiveInstDataAsInt().empty()) {
       SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
       // Remove leading unit dimensions from inst_data
+      // For example, if the inst_data is [1, 1, 32]
+      // it will pass [32] as the unroll/blocking size.
       // Skip it for xegpu nd ops since it will be 2D
       Operation *definingOp = value.getDefiningOp();
       bool skipLeadingUnitDimRemoval =
@@ -155,8 +157,8 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
           (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp,
                xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp));
       if (!skipLeadingUnitDimRemoval) {
-        while (!instData.empty() && instData.front() == 1)
-          instData.erase(instData.begin());
+        auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
+        instData.erase(instData.begin(), it);
       }
       return instData;
     }
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 40013eb161678..6949e07175960 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -255,8 +255,6 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
             1);
   std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
 
-  int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
-
   SmallVector<Value> result;
   for (SmallVector<int64_t> offsets :
        StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
@@ -265,7 +263,7 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
         builder, loc, value, offsets, adjustedTargetShape, staticStrides);
 
     // Reshape to remove leading unit dims if needed
-    if (adjustedTargetShapeRank > targetShapeRank) {
+    if (srcShapeRank > targetShapeRank) {
       auto targetTy = VectorType::get(shape, vecTy.getElementType());
       slice = builder.create<vector::ShapeCastOp>(loc, targetTy, slice);
     }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 4e533dc55333d..7e742af754fbe 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -685,7 +685,7 @@ gpu.module @test_kernel {
 
 // -----
 gpu.module @test_kernel {
-  // CHECK-LABEL: load_gather
+  // CHECK-LABEL: remove_unit_dim_inst_data
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<1x1x32xf32>
   // CHECK: [[cst_0:%.+]] = arith.constant dense<true> : vector<16xi1>
@@ -695,7 +695,7 @@ gpu.module @test_kernel {
   // CHECK: [[ld_1:%.+]] = xegpu.load [[arg0]][[[cst_2]]], [[cst_0]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
   // CHECK: [[ins_0:%.+]] = vector.insert_strided_slice [[ld_0]], [[cst]] {offsets = [0, 0, 0], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
   // CHECK: [[ins_1:%.+]] = vector.insert_strided_slice [[ld_1]], [[ins_0]] {offsets = [0, 0, 16], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
-  gpu.func @load_gather(%src: ui64) -> vector<1x1x32xf32> {
+  gpu.func @remove_unit_dim_inst_data(%src: ui64) -> vector<1x1x32xf32> {
       %cst = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<[[
       [0,   8,  16,  24,  32,  40,  48,  56,
       64,  72,  80,  88,  96, 104, 112, 120,

>From 717b8b4fd1fab563da8eb4386cae1944fb510978 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 24 Oct 2025 17:25:28 +0000
Subject: [PATCH 6/6] Add TODO to try upstream leading-unit-dim cast-away patterns

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index b519746285868..2c37140ad9c76 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -151,6 +151,9 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
       // For example, if the inst_data is [1, 1, 32]
       // it will pass [32] as the unroll/blocking size.
       // Skip it for xegpu nd ops since it will be 2D
+      // TODO: For vector ops, experiment with the
+      // upstream vector remove leading unit dims patterns,
+      // populateCastAwayVectorLeadingOneDimPatterns.
       Operation *definingOp = value.getDefiningOp();
       bool skipLeadingUnitDimRemoval =
           definingOp &&



More information about the Mlir-commits mailing list