[Mlir-commits] [mlir] [MLIR][XeGPU] Decompose unsupported 'vector.transfer_read'-transpose-permutations (PR #182875)

Dmitry Chigarev llvmlistbot at llvm.org
Mon Feb 23 07:52:18 PST 2026


https://github.com/dchigarev created https://github.com/llvm/llvm-project/pull/182875

This PR adds a pattern to the `convert-vector-to-xegpu` pass that decomposes a `vector.transfer_read` whose transpose permutation is unsupported (because of its element type) into a plain `vector.transfer_read` followed by a `vector.transpose`, so that the read itself can still be converted.

Example:

```mlir
// input-ir:
  %0 = vector.transfer_read %source[%offset, %offset], %c0
    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
    in_bounds = [true, true]} : memref<32x64xf16>, vector<8x16xf16>

// mlir-opt %s --convert-vector-to-xegpu
// before PR (no conversion because of unsupported type):
  %0 = vector.transfer_read %source[%offset, %offset], %c0
    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
    in_bounds = [true, true]} : memref<32x64xf16>, vector<8x16xf16>

// mlir-opt %s --convert-vector-to-xegpu
// after PR (decomposed + converted):
  %0 = xegpu.load_nd %source[%offset, %offset]
  %1 = vector.transpose %0
```

>From 4fbd8184c38afddf4afd78d33d59d5a1d6ce6b1d Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Mon, 23 Feb 2026 14:20:41 +0000
Subject: [PATCH] [MLIR][XeGPU] Decompose unsupported
 'vector.transfer_read'-transpose-permutations

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 96 +++++++++++++++++--
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir |  3 +-
 2 files changed, 88 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index c81bb4b455b98..cb76d78d3195e 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -99,6 +99,35 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
+// Returns true when `readOp` should take the block-load lowering path: the
+// chip string is "pvc" or "bmg" and it is not an in-bounds 1D vector read.
+static bool shouldLowerTransferReadToBlockLoad(vector::TransferReadOp readOp) {
+  auto chip = xegpu::getChipStr(readOp);
+  if (chip != "pvc" && chip != "bmg")
+    return false;
+
+  VectorType vecTy = readOp.getVectorType();
+  if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
+    return false;
+
+  return true;
+}
+
+// Returns whether a 'transfer_read with transpose' can be lowered directly:
+// false only when it takes the block-load path with elements under 32 bits.
+static bool isTransferReadTransposeSupported(vector::TransferReadOp readOp) {
+  // Reads that are not block-loads are always reported as supported here.
+  if (!shouldLowerTransferReadToBlockLoad(readOp))
+    return true;
+
+  unsigned minTransposeBitWidth = 32;
+  auto elementType = readOp.getVectorType().getElementType();
+  if (elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
+    return false;
+
+  return true;
+}
+
 static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                 Location loc,
                                                 xegpu::TensorDescType descType,
@@ -538,9 +567,7 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     if (failed(transferPreconditions(rewriter, readOp)))
       return failure();
 
-    // TODO:This check needs to be replaced with proper uArch capability check
-    auto chip = xegpu::getChipStr(readOp);
-    if (chip != "pvc" && chip != "bmg") {
+    if (!shouldLowerTransferReadToBlockLoad(readOp)) {
       // lower to scattered load Op if the target HW doesn't have 2d block load
       // support
       // TODO: add support for OutOfBound access
@@ -551,10 +578,6 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
 
     VectorType vecTy = readOp.getVectorType();
 
-    // Lower using load.gather in 1D case
-    if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
-      return lowerToScatteredLoadOp(readOp, rewriter);
-
     // Perform common data transfer checks.
     if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
       return failure();
@@ -568,9 +591,7 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     bool isTransposeLoad = !readMap.isMinorIdentity();
 
     Type elementType = vecTy.getElementType();
-    unsigned minTransposeBitWidth = 32;
-    if (isTransposeLoad &&
-        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
+    if (isTransposeLoad && !isTransferReadTransposeSupported(readOp))
       return rewriter.notifyMatchFailure(
           readOp, "Unsupported data type for transposition");
 
@@ -841,9 +862,64 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
   }
 };
 
+// Rewrites a 'vector.transfer_read' whose transpose-permutation is not
+// supported into a plain 'transfer_read' followed by a 'vector.transpose'.
+struct TransferReadDecomposeUnsupportedTranspose
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = readOp.getLoc();
+    AffineMap readMap = readOp.getPermutationMap();
+
+    bool isTransposeLoad = !readMap.isMinorIdentity();
+    if (!isTransposeLoad)
+      return failure(); // not a transposing read
+
+    bool isTransposeSupported = isTransferReadTransposeSupported(readOp);
+    if (isTransposeSupported)
+      return failure(); // nothing to decompose
+    // NOTE(review): cast<> asserts on mismatch, so the null check looks dead.
+    auto resultType = cast<VectorType>(readOp.getResult().getType());
+    if (!resultType)
+      return failure();
+
+    // Apply the read's permutation map to the result shape to obtain the
+    // shape of the 'untransposed' read
+    auto newShape = applyPermutationMap(readMap, resultType.getShape());
+    auto newTransferReadRes =
+        VectorType::get(newShape, resultType.getElementType());
+
+    // Step 1. Create 'plain' transfer_read without transpose
+    auto newReadOp = vector::TransferReadOp::create(
+        rewriter, loc, newTransferReadRes, readOp.getBase(),
+        readOp.getIndices(), AffineMap::get(loc.getContext()),
+        readOp.getPadding(), readOp.getMask(), readOp.getInBoundsAttr());
+
+    // Step 2. Transpose the result of the 'plain' transfer_read
+    auto range = llvm::seq<int64_t>(0, readMap.getResults().size());
+    SmallVector<int64_t> perm(range.begin(), range.end());
+    auto permApplied = applyPermutationMap<int64_t>(readMap, perm);
+    auto transposeOp = vector::TransposeOp::create(
+        rewriter, loc, newReadOp.getResult(), permApplied);
+
+    // Step 3. Replace old transfer_read op with 'transpose()'
+    rewriter.replaceOp(readOp, transposeOp);
+    return success();
+  }
+};
+
 struct ConvertVectorToXeGPUPass
     : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
   void runOnOperation() override {
+    RewritePatternSet prepareTransferReadPatterns(&getContext());
+    prepareTransferReadPatterns.add<TransferReadDecomposeUnsupportedTranspose>(
+        prepareTransferReadPatterns.getContext());
+    if (failed(applyPatternsGreedily(getOperation(),
+                                     std::move(prepareTransferReadPatterns))))
+      return signalPassFailure();
+
     RewritePatternSet patterns(&getContext());
     populateVectorToXeGPUConversionPatterns(patterns);
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index b58f9b30ed726..5f4aff6019979 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -297,7 +297,8 @@ gpu.func @load_transpose_f16(%source: memref<32x64xf16>,
 }
 
 // LOAD-ND-LABEL:  @load_transpose_f16(
-// LOAD-ND:        vector.transfer_read
+// LOAD-ND:        %[[LOAD:.*]] = xegpu.load_nd
+// LOAD-ND:        vector.transpose %[[LOAD]], [1, 0] : vector<16x8xf16> to vector<8x16xf16>
 
 // LOAD-GATHER-LABEL:  @load_transpose_f16(
 // LOAD-GATHER-SAME:    %[[SRC:.+]]: memref<32x64xf16>,



More information about the Mlir-commits mailing list