[Mlir-commits] [mlir] [MLIR][XeGPU][VectorToXeGPU] Strip leading unit dims for transfer_read/write (PR #199675)

Andrey Pavlenko llvmlistbot at llvm.org
Tue May 26 09:18:04 PDT 2026


https://github.com/AndreyPavlenko updated https://github.com/llvm/llvm-project/pull/199675

>From 10806b4e44c3b7440440d2d1a5b0e603e7584884 Mon Sep 17 00:00:00 2001
From: Andrey Pavlenko <andrey.a.pavlenko at gmail.com>
Date: Tue, 26 May 2026 13:12:03 +0000
Subject: [PATCH] [MLIR][XeGPU][VectorToXeGPU] Strip leading unit dims for
 transfer_read/write

Strip leading unit dims down to rank 2 to allow lowering to xegpu.load_nd/store_nd
operations, which support only 1D and 2D vectors.
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 60 +++++++++++++++----
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir | 25 ++++++++
 .../transfer-write-to-xegpu.mlir              | 24 ++++++++
 3 files changed, 98 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index d8eca95cbf23c..8490a97cd2afe 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -53,7 +53,8 @@ static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                             MemRefType memTy) {
   // Validate only vector as the basic vector store and load ops guarantee
   // XeGPU-compatible memref source.
-  unsigned vecRank = vecTy.getRank();
+  int64_t vecRank =
+      vecTy.getShape().drop_while([](int64_t d) { return d == 1; }).size();
   if (!(vecRank == 1 || vecRank == 2))
     return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
 
@@ -549,6 +550,29 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
       return failure();
     auto readMemTy = cast<MemRefType>(readOp.getShapedType());
     VectorType loadedVecTy = readOp.getVectorType();
+    std::function<Operation *(Operation *)> castResult = [](Operation *op) {
+      return op;
+    };
+    // Strip leading unit dims up to rank-2
+    if (loadedVecTy.getRank() > 2) {
+      auto shape = loadedVecTy.getShape().drop_back(2);
+      size_t numLeading = 0;
+      while (numLeading < shape.size() && shape[numLeading] == 1)
+        ++numLeading;
+      if (numLeading > 0) {
+        loadedVecTy =
+            VectorType::get(loadedVecTy.getShape().drop_front(numLeading),
+                            loadedVecTy.getElementType());
+        castResult = [&rewriter, &loc, numLeading](Operation *op) {
+          auto result = op->getResult(0);
+          auto ty = cast<VectorType>(result.getType());
+          SmallVector<int64_t> shape(numLeading, 1);
+          shape.append(ty.getShape().begin(), ty.getShape().end());
+          ty = VectorType::get(shape, ty.getElementType());
+          return vector::ShapeCastOp::create(rewriter, loc, ty, result);
+        };
+      }
+    }
     bool isOutOfBounds = readOp.hasOutOfBoundsDim();
     // Check if the memref has address space 3 (shared local memory)
     bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(readMemTy);
@@ -582,7 +606,7 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
           rewriter, loc, loadedVecTy, createMemDescOp.getResult(), indices,
           /*layout=*/nullptr);
 
-      rewriter.replaceOp(readOp, loadMatrixOp.getResult());
+      rewriter.replaceOp(readOp, castResult(loadMatrixOp.getOperation()));
       return success();
     }
 
@@ -591,7 +615,7 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     // Lower to scattered load Op if the target HW doesn't have 2d block load
     // support and the load is not from shared memory.
     if ((chip != "pvc" && chip != "bmg" && chip != "cri") ||
-        readOp.getVectorType().getRank() > 2) {
+        loadedVecTy.getRank() > 2) {
 
       // TODO: add support for OutOfBound access
       if (isOutOfBounds)
@@ -669,7 +693,7 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
       loadedOp = vector::TransposeOp::create(rewriter, loc,
                                              loadedOp->getResult(0), perm);
     }
-    rewriter.replaceOp(readOp, loadedOp);
+    rewriter.replaceOp(readOp, castResult(loadedOp));
 
     return success();
   }
@@ -691,6 +715,20 @@ struct TransferWriteLowering
     // Check if the memref has address space 3 (shared local memory)
     bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(writeMemTy);
 
+    // Strip leading unit dims up to rank-2.
+    Value writeVec = writeOp.getVector();
+    if (vecTy.getRank() > 2) {
+      auto shape = vecTy.getShape().drop_back(2);
+      size_t numLeading = 0;
+      while (numLeading < shape.size() && shape[numLeading] == 1)
+        ++numLeading;
+      if (numLeading > 0) {
+        vecTy = VectorType::get(vecTy.getShape().drop_front(numLeading),
+                                vecTy.getElementType());
+        writeVec = vector::ShapeCastOp::create(rewriter, loc, vecTy, writeVec);
+      }
+    }
+
     // For shared local memory (address space 3), use create_mem_desc +
     // store_matrix
     if (isSharedMemory) {
@@ -711,7 +749,7 @@ struct TransferWriteLowering
       SmallVector<OpFoldResult> indices =
           getAsOpFoldResult(writeOp.getIndices());
 
-      xegpu::StoreMatrixOp::create(rewriter, loc, writeOp.getVector(),
+      xegpu::StoreMatrixOp::create(rewriter, loc, writeVec,
                                    createMemDescOp.getResult(), indices,
                                    /*layout=*/nullptr);
 
@@ -724,7 +762,7 @@ struct TransferWriteLowering
     // Lower to scattered store Op if the target HW doesn't have 2d block
     // store support and the memref is not SLM.
     if ((chip != "pvc" && chip != "bmg" && chip != "cri") ||
-        writeOp.getVectorType().getRank() > 2) {
+        vecTy.getRank() > 2) {
 
       // TODO: add support for OutOfBound access
       if (writeOp.hasOutOfBoundsDim())
@@ -752,11 +790,11 @@ struct TransferWriteLowering
     xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
         rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
 
-    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
-                                            ndDesc, indices,
-                                            /*l1_hint=*/hint,
-                                            /*l2_hint=*/hint, /*l3_hint=*/hint,
-                                            /*layout=*/nullptr);
+    auto storeOp =
+        xegpu::StoreNdOp::create(rewriter, loc, writeVec, ndDesc, indices,
+                                 /*l1_hint=*/hint,
+                                 /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                 /*layout=*/nullptr);
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 7ba576313c6d4..1e31b8b9fd168 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -613,3 +613,28 @@ gpu.func @load_1D_vector_addrspace3_unsupported(%source: memref<32xf32, 3>,
 // LOAD-GATHER: vector.transfer_read
 
 }
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_strip_leading_unit_dims(
+    %src: memref<1x24x1024x1024xf16>,
+    %i: index, %j: index) -> vector<1x1x1x1024xf16> {
+  %pad = arith.constant 0.0 : f16
+  %0 = vector.transfer_read %src[%i, %j, %j, %i], %pad
+    {in_bounds = [true, true, true, true]}
+    : memref<1x24x1024x1024xf16>, vector<1x1x1x1024xf16>
+  gpu.return %0 : vector<1x1x1x1024xf16>
+}
+
+// LOAD-ND-LABEL:  @load_strip_leading_unit_dims(
+// LOAD-ND-SAME:   %[[SRC:.+]]: memref<1x24x1024x1024xf16>,
+// LOAD-ND-SAME:   %[[I:.+]]: index, %[[J:.+]]: index
+// LOAD-ND:        xegpu.create_nd_tdesc
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd {{.*}} -> vector<1x1024xf16>
+// LOAD-ND:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1024xf16> to vector<1x1x1x1024xf16>
+// LOAD-ND:        gpu.return %[[CAST]]
+
+// LOAD-GATHER-LABEL:  @load_strip_leading_unit_dims(
+// LOAD-GATHER:        xegpu.load {{.*}} -> vector<1x1x1x1024xf16>
+
+}
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index a7cd3d7652d85..c45896d149c8b 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -412,3 +412,27 @@ gpu.func @store_1D_vector_addrspace3_unsupported(%vec: vector<8xf32>,
 // STORE-SCATTER: vector.transfer_write
 
 }
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_strip_leading_unit_dims(
+    %vec: vector<1x1x1x1024xf16>,
+    %dst: memref<1x24x1152x1024xf16>,
+    %i: index, %j: index) {
+  vector.transfer_write %vec, %dst[%i, %j, %j, %i]
+    {in_bounds = [true, true, true, true]}
+    : vector<1x1x1x1024xf16>, memref<1x24x1152x1024xf16>
+  gpu.return
+}
+
+// STORE-ND-LABEL: @store_strip_leading_unit_dims(
+// STORE-ND-SAME:  %[[VEC:.+]]: vector<1x1x1x1024xf16>,
+// STORE-ND-SAME:  %[[DST:.+]]: memref<1x24x1152x1024xf16>,
+// STORE-ND:       %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x1x1024xf16> to vector<1x1024xf16>
+// STORE-ND:       xegpu.create_nd_tdesc
+// STORE-ND:       xegpu.store_nd %[[CAST]], {{.*}} : vector<1x1024xf16>
+
+// STORE-SCATTER-LABEL: @store_strip_leading_unit_dims(
+// STORE-SCATTER:       xegpu.store {{.*}} : vector<1x1x1x1024xf16>
+
+}



More information about the Mlir-commits mailing list