[Mlir-commits] [mlir] [MLIR][XeGPU] Decompose unsupported 'vector.transfer_read'-transpose-permutations (PR #182875)

Dmitry Chigarev llvmlistbot at llvm.org
Wed Mar 4 09:06:45 PST 2026


https://github.com/dchigarev updated https://github.com/llvm/llvm-project/pull/182875

>From 3bd5c2d31ffdbf242b25ff35ef0905cd803297ad Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Mon, 23 Feb 2026 14:20:41 +0000
Subject: [PATCH 1/4] [MLIR][XeGPU] Decompose unsupported
 'vector.transfer_read'-transpose-permutations

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 96 +++++++++++++++++--
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir |  3 +-
 2 files changed, 88 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index c81bb4b455b98..cb76d78d3195e 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -99,6 +99,35 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
+// Checks whether the given `vector.transfer_read` operation can be
+// lowered to a block-load.
+static bool shouldLowerTransferReadToBlockLoad(vector::TransferReadOp readOp) {
+  auto chip = xegpu::getChipStr(readOp);
+  if (chip != "pvc" && chip != "bmg")
+    return false;
+
+  VectorType vecTy = readOp.getVectorType();
+  if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
+    return false;
+
+  return true;
+}
+
+// Checks whether the given 'transfer_read with transpose' can be directly
+// lowered to xegpu.load/_nd ops with transpose support.
+static bool isTransferReadTransposeSupported(vector::TransferReadOp readOp) {
+  // Scatter-load always supports transpose-permutations
+  if (!shouldLowerTransferReadToBlockLoad(readOp))
+    return true;
+
+  unsigned minTransposeBitWidth = 32;
+  auto elementType = readOp.getVectorType().getElementType();
+  if (elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
+    return false;
+
+  return true;
+}
+
 static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                 Location loc,
                                                 xegpu::TensorDescType descType,
@@ -538,9 +567,7 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     if (failed(transferPreconditions(rewriter, readOp)))
       return failure();
 
-    // TODO:This check needs to be replaced with proper uArch capability check
-    auto chip = xegpu::getChipStr(readOp);
-    if (chip != "pvc" && chip != "bmg") {
+    if (!shouldLowerTransferReadToBlockLoad(readOp)) {
       // lower to scattered load Op if the target HW doesn't have 2d block load
       // support
       // TODO: add support for OutOfBound access
@@ -551,10 +578,6 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
 
     VectorType vecTy = readOp.getVectorType();
 
-    // Lower using load.gather in 1D case
-    if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
-      return lowerToScatteredLoadOp(readOp, rewriter);
-
     // Perform common data transfer checks.
     if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
       return failure();
@@ -568,9 +591,7 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     bool isTransposeLoad = !readMap.isMinorIdentity();
 
     Type elementType = vecTy.getElementType();
-    unsigned minTransposeBitWidth = 32;
-    if (isTransposeLoad &&
-        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
+    if (isTransposeLoad && !isTransferReadTransposeSupported(readOp))
       return rewriter.notifyMatchFailure(
           readOp, "Unsupported data type for transposition");
 
@@ -841,9 +862,64 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
   }
 };
 
+// Splits 'vector.transfer_read' with unsupported transpose-permutations
+// into 'transfer_read() + transpose()'.
+struct TransferReadDecomposeUnsupportedTranspose
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = readOp.getLoc();
+    AffineMap readMap = readOp.getPermutationMap();
+
+    bool isTransposeLoad = !readMap.isMinorIdentity();
+    if (!isTransposeLoad)
+      return failure();
+
+    bool isTransposeSupported = isTransferReadTransposeSupported(readOp);
+    if (isTransposeSupported)
+      return failure();
+
+    auto resultType = cast<VectorType>(readOp.getResult().getType());
+    if (!resultType)
+      return failure();
+
+    // 'Revert' permutation for the transfer_read result shape to make it
+    // 'untransposed'
+    auto newShape = applyPermutationMap(readMap, resultType.getShape());
+    auto newTransferReadRes =
+        VectorType::get(newShape, resultType.getElementType());
+
+    // Step 1. Create 'plain' transfer_read without transpose
+    auto newReadOp = vector::TransferReadOp::create(
+        rewriter, loc, newTransferReadRes, readOp.getBase(),
+        readOp.getIndices(), AffineMap::get(loc.getContext()),
+        readOp.getPadding(), readOp.getMask(), readOp.getInBoundsAttr());
+
+    // Step 2. Transpose the result of the 'plain' transfer_read
+    auto range = llvm::seq<int64_t>(0, readMap.getResults().size());
+    SmallVector<int64_t> perm(range.begin(), range.end());
+    auto permApplied = applyPermutationMap<int64_t>(readMap, perm);
+    auto transposeOp = vector::TransposeOp::create(
+        rewriter, loc, newReadOp.getResult(), permApplied);
+
+    // Step 3. Replace old transfer_read op with 'transpose()'
+    rewriter.replaceOp(readOp, transposeOp);
+    return success();
+  }
+};
+
 struct ConvertVectorToXeGPUPass
     : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
   void runOnOperation() override {
+    RewritePatternSet prepareTransferReadPatterns(&getContext());
+    prepareTransferReadPatterns.add<TransferReadDecomposeUnsupportedTranspose>(
+        prepareTransferReadPatterns.getContext());
+    if (failed(applyPatternsGreedily(getOperation(),
+                                     std::move(prepareTransferReadPatterns))))
+      return signalPassFailure();
+
     RewritePatternSet patterns(&getContext());
     populateVectorToXeGPUConversionPatterns(patterns);
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index b58f9b30ed726..5f4aff6019979 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -297,7 +297,8 @@ gpu.func @load_transpose_f16(%source: memref<32x64xf16>,
 }
 
 // LOAD-ND-LABEL:  @load_transpose_f16(
-// LOAD-ND:        vector.transfer_read
+// LOAD-ND:        %[[LOAD:.*]] = xegpu.load_nd
+// LOAD-ND:        vector.transpose %[[LOAD]], [1, 0] : vector<16x8xf16> to vector<8x16xf16>
 
 // LOAD-GATHER-LABEL:  @load_transpose_f16(
 // LOAD-GATHER-SAME:    %[[SRC:.+]]: memref<32x64xf16>,

>From 80f849d9e75a7c6625862001956d6ec88743416f Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Sun, 1 Mar 2026 21:46:53 +0000
Subject: [PATCH 2/4] apply all patterns in the same pattern set

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../Conversion/VectorToXeGPU/VectorToXeGPU.cpp    | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index cb76d78d3195e..087a3fb6c8ff9 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -913,13 +913,6 @@ struct TransferReadDecomposeUnsupportedTranspose
 struct ConvertVectorToXeGPUPass
     : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
   void runOnOperation() override {
-    RewritePatternSet prepareTransferReadPatterns(&getContext());
-    prepareTransferReadPatterns.add<TransferReadDecomposeUnsupportedTranspose>(
-        prepareTransferReadPatterns.getContext());
-    if (failed(applyPatternsGreedily(getOperation(),
-                                     std::move(prepareTransferReadPatterns))))
-      return signalPassFailure();
-
     RewritePatternSet patterns(&getContext());
     populateVectorToXeGPUConversionPatterns(patterns);
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
@@ -931,8 +924,8 @@ struct ConvertVectorToXeGPUPass
 
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
-           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
-          patterns.getContext());
+  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
+               ScatterLowering, GatherLowering, StoreLowering,
+               ContractionLowering, TransferReadDecomposeUnsupportedTranspose>(
+      patterns.getContext());
 }

>From 4a135be010095e030f1191a1da3a4907c883abf8 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 3 Mar 2026 18:20:43 +0000
Subject: [PATCH 3/4] move transpose decomposition into an existing pattern

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 164 +++++++-----------
 1 file changed, 60 insertions(+), 104 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 087a3fb6c8ff9..cad58e9423b76 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -99,35 +99,6 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
-// Checks whether the given `vector.transfer_read` operation can be
-// lowered to a block-load.
-static bool shouldLowerTransferReadToBlockLoad(vector::TransferReadOp readOp) {
-  auto chip = xegpu::getChipStr(readOp);
-  if (chip != "pvc" && chip != "bmg")
-    return false;
-
-  VectorType vecTy = readOp.getVectorType();
-  if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
-    return false;
-
-  return true;
-}
-
-// Checks whether the given 'transfer_read with transpose' can be directly
-// lowered to xegpu.load/_nd ops with transpose support.
-static bool isTransferReadTransposeSupported(vector::TransferReadOp readOp) {
-  // Scatter-load always supports transpose-permutations
-  if (!shouldLowerTransferReadToBlockLoad(readOp))
-    return true;
-
-  unsigned minTransposeBitWidth = 32;
-  auto elementType = readOp.getVectorType().getElementType();
-  if (elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
-    return false;
-
-  return true;
-}
-
 static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                 Location loc,
                                                 xegpu::TensorDescType descType,
@@ -567,7 +538,9 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     if (failed(transferPreconditions(rewriter, readOp)))
       return failure();
 
-    if (!shouldLowerTransferReadToBlockLoad(readOp)) {
+    // TODO:This check needs to be replaced with proper uArch capability check
+    auto chip = xegpu::getChipStr(readOp);
+    if (chip != "pvc" && chip != "bmg") {
       // lower to scattered load Op if the target HW doesn't have 2d block load
       // support
       // TODO: add support for OutOfBound access
@@ -576,10 +549,14 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
       return lowerToScatteredLoadOp(readOp, rewriter);
     }
 
-    VectorType vecTy = readOp.getVectorType();
+    VectorType loadedVecTy = readOp.getVectorType();
+
+    // Lower using load.gather in 1D case
+    if (loadedVecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
+      return lowerToScatteredLoadOp(readOp, rewriter);
 
     // Perform common data transfer checks.
-    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
+    if (failed(storeLoadPreconditions(rewriter, readOp, loadedVecTy)))
       return failure();
 
     bool isOutOfBounds = readOp.hasOutOfBoundsDim();
@@ -590,37 +567,64 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     AffineMap readMap = readOp.getPermutationMap();
     bool isTransposeLoad = !readMap.isMinorIdentity();
 
-    Type elementType = vecTy.getElementType();
-    if (isTransposeLoad && !isTransferReadTransposeSupported(readOp))
-      return rewriter.notifyMatchFailure(
-          readOp, "Unsupported data type for transposition");
-
-    // If load is transposed, get the base shape for the tensor descriptor.
-    SmallVector<int64_t> descShape(vecTy.getShape());
-    if (isTransposeLoad)
-      std::reverse(descShape.begin(), descShape.end());
+    Type elementType = loadedVecTy.getElementType();
+    unsigned minTransposeBitWidth = 32;
+    // Here we separate two transpose cases:
+    // 1. Elements >= 32 bits: use the xegpu.load_nd transpose attribute.
+    // 2. Narrower elements: plain load_nd followed by vector.transpose.
+    bool shouldUseTransposeAttr =
+        isTransposeLoad &&
+        elementType.getIntOrFloatBitWidth() >= minTransposeBitWidth;
+
+    SmallVector<int64_t> descShape(loadedVecTy.getShape());
+    if (isTransposeLoad) {
+      // Descriptor shape = result shape with the read permutation undone.
+      // NOTE(review): inversePermutation yields a null map for non-invertible
+      // (e.g. rank-reducing) maps; confirm such reads are rejected earlier.
+      auto inversedMap = inversePermutation(readMap);
+      descShape = applyPermutationMap(inversedMap, loadedVecTy.getShape());
+      if (!shouldUseTransposeAttr) {
+        // When a separate vector.transpose is used instead of the
+        // xegpu.load_nd transpose attribute, load_nd returns the data
+        // untransposed, so the loaded type must use the inverse-permuted
+        // (i.e. descriptor) shape as well.
+        auto newShape =
+            applyPermutationMap(inversedMap, loadedVecTy.getShape());
+        loadedVecTy = VectorType::get(newShape, loadedVecTy.getElementType());
+      }
+    }
     auto descType = xegpu::TensorDescType::get(
         descShape, elementType, /*array_length=*/1,
         /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
-
     DenseI64ArrayAttr transposeAttr =
-        !isTransposeLoad ? nullptr
-                         : DenseI64ArrayAttr::get(rewriter.getContext(),
-                                                  ArrayRef<int64_t>{1, 0});
+        !shouldUseTransposeAttr
+            ? nullptr
+            : DenseI64ArrayAttr::get(rewriter.getContext(),
+                                     ArrayRef<int64_t>{1, 0});
     auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
         rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
-        vecTy.getRank());
+        loadedVecTy.getRank());
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
     xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
         rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
 
-    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
-                                          /*packed=*/nullptr, transposeAttr,
-                                          /*l1_hint=*/hint,
-                                          /*l2_hint=*/hint, /*l3_hint=*/hint,
-                                          /*layout=*/nullptr);
-    rewriter.replaceOp(readOp, loadOp);
+    Operation *loadedOp =
+        xegpu::LoadNdOp::create(rewriter, loc, loadedVecTy, ndDesc, indices,
+                                /*packed=*/nullptr, transposeAttr,
+                                /*l1_hint=*/hint,
+                                /*l2_hint=*/hint, /*l3_hint=*/hint,
+                                /*layout=*/nullptr);
+    if (isTransposeLoad && !shouldUseTransposeAttr) {
+      // Apply the read permutation to the untransposed load_nd result with
+      // a separate vector.transpose operation.
+      auto range = llvm::seq<int64_t>(0, readMap.getResults().size());
+      SmallVector<int64_t> perm(range.begin(), range.end());
+      auto permApplied = applyPermutationMap<int64_t>(readMap, perm);
+      loadedOp = vector::TransposeOp::create(
+          rewriter, loc, loadedOp->getResult(0), permApplied);
+    }
+    rewriter.replaceOp(readOp, loadedOp);
 
     return success();
   }
@@ -862,54 +866,6 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
   }
 };
 
-// Splits 'vector.transfer_read' with unsupported transpose-permutations
-// into 'transfer_read() + transpose()'.
-struct TransferReadDecomposeUnsupportedTranspose
-    : public OpRewritePattern<vector::TransferReadOp> {
-  using Base::Base;
-
-  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
-                                PatternRewriter &rewriter) const override {
-    Location loc = readOp.getLoc();
-    AffineMap readMap = readOp.getPermutationMap();
-
-    bool isTransposeLoad = !readMap.isMinorIdentity();
-    if (!isTransposeLoad)
-      return failure();
-
-    bool isTransposeSupported = isTransferReadTransposeSupported(readOp);
-    if (isTransposeSupported)
-      return failure();
-
-    auto resultType = cast<VectorType>(readOp.getResult().getType());
-    if (!resultType)
-      return failure();
-
-    // 'Revert' permutation for the transfer_read result shape to make it
-    // 'untransposed'
-    auto newShape = applyPermutationMap(readMap, resultType.getShape());
-    auto newTransferReadRes =
-        VectorType::get(newShape, resultType.getElementType());
-
-    // Step 1. Create 'plain' transfer_read without transpose
-    auto newReadOp = vector::TransferReadOp::create(
-        rewriter, loc, newTransferReadRes, readOp.getBase(),
-        readOp.getIndices(), AffineMap::get(loc.getContext()),
-        readOp.getPadding(), readOp.getMask(), readOp.getInBoundsAttr());
-
-    // Step 2. Transpose the result of the 'plain' transfer_read
-    auto range = llvm::seq<int64_t>(0, readMap.getResults().size());
-    SmallVector<int64_t> perm(range.begin(), range.end());
-    auto permApplied = applyPermutationMap<int64_t>(readMap, perm);
-    auto transposeOp = vector::TransposeOp::create(
-        rewriter, loc, newReadOp.getResult(), permApplied);
-
-    // Step 3. Replace old transfer_read op with 'transpose()'
-    rewriter.replaceOp(readOp, transposeOp);
-    return success();
-  }
-};
-
 struct ConvertVectorToXeGPUPass
     : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
   void runOnOperation() override {
@@ -924,8 +880,8 @@ struct ConvertVectorToXeGPUPass
 
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
-               ScatterLowering, GatherLowering, StoreLowering,
-               ContractionLowering, TransferReadDecomposeUnsupportedTranspose>(
-      patterns.getContext());
+  patterns
+      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
+           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
+          patterns.getContext());
 }

>From 51b6a0f23b78486c77caf04bbea4370626ef911d Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Wed, 4 Mar 2026 17:06:29 +0000
Subject: [PATCH 4/4] fix code formatting

Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
 mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 39c16f0d432b9..719f1af79c220 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -566,7 +566,8 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
 
     // Perform common data transfer checks.
     auto readMemTy = cast<MemRefType>(readOp.getShapedType());
-    if (failed(storeLoadPreconditions(rewriter, readOp, loadedVecTy, readMemTy)))
+    if (failed(
+            storeLoadPreconditions(rewriter, readOp, loadedVecTy, readMemTy)))
       return failure();
 
     bool isOutOfBounds = readOp.hasOutOfBoundsDim();



More information about the Mlir-commits mailing list