[Mlir-commits] [mlir] 621ed04 - [MLIR][XeGPU]Enhance Pack/Unpack for XeGPUUnroll (#163459)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Fri Oct 24 11:04:10 PDT 2025
    
    
  
Author: Nishant Patel
Date: 2025-10-24T11:04:05-07:00
New Revision: 621ed04e28787ade92b98e296332ac71d1b81678
URL: https://github.com/llvm/llvm-project/commit/621ed04e28787ade92b98e296332ac71d1b81678
DIFF: https://github.com/llvm/llvm-project/commit/621ed04e28787ade92b98e296332ac71d1b81678.diff
LOG: [MLIR][XeGPU]Enhance Pack/Unpack for XeGPUUnroll (#163459)
This PR changes the pack/unpack method used for unrolling so that a lower-rank
slice can be extracted from, and inserted into, the full vector; extracted
slices are reshaped with vector.shape_cast to drop the leading unit dimensions.
It also removes leading unit dimensions from inst_data, if there are any.
Added: 
    
Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
    mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
    mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index f77784abaf0b2..2c37140ad9c76 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -145,8 +145,26 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
   xegpu::DistributeLayoutAttr layout =
       xegpu::getDistributeLayoutAttr(operandOrResult);
   if (layout && layout.isForSubgroup()) {
-    if (!layout.getEffectiveInstDataAsInt().empty())
-      return layout.getEffectiveInstDataAsInt();
+    if (!layout.getEffectiveInstDataAsInt().empty()) {
+      SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
+      // Remove leading unit dimensions from inst_data
+      // For example, if the inst_data is [1, 1, 32]
+      // it will pass [32] as the unroll/blocking size.
+      // Skip it for xegpu nd ops since it will be 2D
+      // TODO: For vectors ops, experiment with the
+      // upstream vector remove leading unit dims patterns,
+      // populateCastAwayVectorLeadingOneDimPatterns.
+      Operation *definingOp = value.getDefiningOp();
+      bool skipLeadingUnitDimRemoval =
+          definingOp &&
+          (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp,
+               xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp));
+      if (!skipLeadingUnitDimRemoval) {
+        auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
+        instData.erase(instData.begin(), it);
+      }
+      return instData;
+    }
 
     if (auto type = dyn_cast<ShapedType>(value.getType()))
       return llvm::to_vector(type.getShape());
@@ -354,7 +372,6 @@ void XeGPUBlockingPass::runOnOperation() {
           // To create a new attribute with a different chunk_size:
           auto newEncoding = xegpu::ScatterTensorDescAttr::get(
               ctx, tdescTy.getMemorySpace(), blockedChunkSize);
-
           encoding = newEncoding;
         }
       }
@@ -363,7 +380,7 @@ void XeGPUBlockingPass::runOnOperation() {
           xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
                                      tdescTy.getLayoutAttr().dropInstData());
     } else {
-      newTy = type.clone(tileShape, elemTy);
+      newTy = VectorType::get(tileShape, elemTy);
     }
 
     if (returnSingleType)
diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index aafa1b7deb84b..e6e71cc29a80a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -66,8 +66,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
   Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
                Location loc, PatternRewriter &rewriter) const {
     if (auto vecTy = dyn_cast<VectorType>(destTy)) {
-      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
-             "Expecting blockSize size to match the rank of destTy.");
       auto shape = vecTy.getShape();
       return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
     }
@@ -93,8 +91,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
                           ArrayRef<int64_t> blockSize, Location loc,
                           PatternRewriter &rewriter) const {
     if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
-      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
-             "Expecting blockSize size to match the rank of src.");
       return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                      blockSize);
     }
@@ -635,7 +631,7 @@ struct UnrollLoadGatherOpWithOffset
     VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
     VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
     Type elemTy = valueTy.getElementType();
-    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+    VectorType newValueTy = VectorType::get(*targetShape, elemTy);
 
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
diff  --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2c56a438ea62c..b4605cd7e94d6 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -246,11 +246,28 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
   if (!computeShapeRatio(srcShape, shape))
     return {value};
 
+  int64_t srcShapeRank = srcShape.size();
+  int64_t targetShapeRank = shape.size();
+
+  SmallVector<int64_t> adjustedTargetShape(srcShape.size());
+  int64_t rankDiff = srcShapeRank - targetShapeRank;
+  std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
+            1);
+  std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
+
   SmallVector<Value> result;
-  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
+  for (SmallVector<int64_t> offsets :
+       StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
     SmallVector<int64_t> staticStrides(offsets.size(), 1);
-    result.push_back(vector::ExtractStridedSliceOp::create(
-        builder, loc, value, offsets, shape, staticStrides));
+    Value slice = vector::ExtractStridedSliceOp::create(
+        builder, loc, value, offsets, adjustedTargetShape, staticStrides);
+
+    // Reshape to remove leading unit dims if needed
+    if (srcShapeRank > targetShapeRank) {
+      auto targetTy = VectorType::get(shape, vecTy.getElementType());
+      slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
+    }
+    result.push_back(slice);
   }
 
   return result;
@@ -274,7 +291,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
 
   for (auto [src, offsets] :
        llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
-    SmallVector<int64_t> staticStrides(offsets.size(), 1);
+    SmallVector<int64_t> staticStrides(tileShape.size(), 1);
     result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
                                                   offsets, staticStrides);
   }
diff  --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index fe4f44c0b02ab..7e742af754fbe 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -682,3 +682,73 @@ gpu.module @test_kernel {
     gpu.return
   }
 }
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: remove_unit_dim_inst_data
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<1x1x32xf32>
+  // CHECK: [[cst_0:%.+]] = arith.constant dense<true> : vector<16xi1>
+  // CHECK: [[cst_1:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+  // CHECK: [[cst_2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+  // CHECK: [[ld_0:%.+]] = xegpu.load [[arg0]][[[cst_1]]], [[cst_0]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  // CHECK: [[ld_1:%.+]] = xegpu.load [[arg0]][[[cst_2]]], [[cst_0]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  // CHECK: [[ins_0:%.+]] = vector.insert_strided_slice [[ld_0]], [[cst]] {offsets = [0, 0, 0], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
+  // CHECK: [[ins_1:%.+]] = vector.insert_strided_slice [[ld_1]], [[ins_0]] {offsets = [0, 0, 16], strides = [1]} : vector<16xf32> into vector<1x1x32xf32>
+  gpu.func @remove_unit_dim_inst_data(%src: ui64) -> vector<1x1x32xf32> {
+      %cst = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<[[
+      [0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248]
+      ]]> : vector<1x1x32xindex>
+
+      %mask = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>} dense<true> : vector<1x1x32xi1>
+      %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+
+      gpu.return %ld : vector<1x1x32xf32>
+  }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [1, 16]>
+gpu.module @test_kernel {
+  // CHECK-LABEL: load_store_nd_with_offsets
+  // CHECK-SAME: [[arg0:%.+]]: memref<1024x1024xf32>, [[arg1:%.+]]: memref<1024x1024xf32>, [[arg2:%.+]]: memref<1024x1024xf32>
+  // CHECK-DAG: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+  // CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
+  // CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
+  // CHECK: [[tdesc_a:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
+  // CHECK: [[tdesc_b:%.+]] = xegpu.create_nd_tdesc [[arg1]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
+  // CHECK: [[tdesc_c:%.+]] = xegpu.create_nd_tdesc [[arg2]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
+  // CHECK: [[ld_a0:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c0]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+  // CHECK: [[ld_a1:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c16]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+  // CHECK: [[ld_b0:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c0]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+  // CHECK: [[ld_b1:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c16]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+  // CHECK: [[cast_a0:%.+]] = vector.shape_cast [[ld_a0]] : vector<1x16xf32> to vector<16xf32>
+  // CHECK: [[cast_b0:%.+]] = vector.shape_cast [[ld_b0]] : vector<1x16xf32> to vector<16xf32>
+  // CHECK: [[add0:%.+]] = arith.addf [[cast_a0]], [[cast_b0]] : vector<16xf32>
+  // CHECK: [[ins0:%.+]] = vector.insert_strided_slice [[add0]], [[cst]] {offsets = [0, 0], strides = [1]} : vector<16xf32> into vector<1x32xf32>
+  // CHECK: [[cast_a1:%.+]] = vector.shape_cast [[ld_a1]] : vector<1x16xf32> to vector<16xf32>
+  // CHECK: [[cast_b1:%.+]] = vector.shape_cast [[ld_b1]] : vector<1x16xf32> to vector<16xf32>
+  // CHECK: [[add1:%.+]] = arith.addf [[cast_a1]], [[cast_b1]] : vector<16xf32>
+  // CHECK: [[ins1:%.+]] = vector.insert_strided_slice [[add1]], [[ins0]] {offsets = [0, 16], strides = [1]} : vector<16xf32> into vector<1x32xf32>
+  // CHECK: [[ext0:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
+  // CHECK: [[ext1:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 16], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
+  // CHECK: xegpu.store_nd [[ext0]], [[tdesc_c]][[[c0]], [[c0]]]  : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
+  // CHECK: xegpu.store_nd [[ext1]], [[tdesc_c]][[[c0]], [[c16]]]  : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
+  gpu.func @load_store_nd_with_offsets(%A: memref<1024x1024xf32>, %B: memref<1024x1024xf32>, %C: memref<1024x1024xf32>) {
+    %c0 = arith.constant 0 : index
+
+    %a_tdesc = xegpu.create_nd_tdesc %A : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
+    %b_tdesc = xegpu.create_nd_tdesc %B : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
+    %c_tdesc = xegpu.create_nd_tdesc %C : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x32xf32, #l>
+
+    %a = xegpu.load_nd %a_tdesc[%c0, %c0] : !xegpu.tensor_desc<1x32xf32, #l> -> vector<1x32xf32>
+    %b = xegpu.load_nd %b_tdesc[%c0, %c0] : !xegpu.tensor_desc<1x32xf32, #l> -> vector<1x32xf32>
+
+    %result = arith.addf %a, %b {layout_result_0 = #l} : vector<1x32xf32>
+    xegpu.store_nd %result, %c_tdesc[%c0, %c0] : vector<1x32xf32>, !xegpu.tensor_desc<1x32xf32, #l>
+    gpu.return
+  }
+}
        
    
    
More information about the Mlir-commits mailing list