[Mlir-commits] [mlir] [MLIR][XeGPU][VectorToXeGPU] Fixed lowering of transfer_read/write for rank > 2 (PR #193308)
Dmitry Chigarev
llvmlistbot at llvm.org
Fri Apr 24 10:04:19 PDT 2026
https://github.com/dchigarev updated https://github.com/llvm/llvm-project/pull/193308
>From fb0dc5f903aeba0421c554fb3f4988929d587e37 Mon Sep 17 00:00:00 2001
From: Andrey Pavlenko <andrey.a.pavlenko at gmail.com>
Date: Tue, 21 Apr 2026 20:04:43 +0000
Subject: [PATCH] [MLIR][XeGPU] Fixed lowering of vector.transfer_read/write to
XeGPU for rank > 2
If the vector rank is > 2, gather load / scatter store ops are used.
Increased the maximum XeGPU value type rank from 6 to 8.
---
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 4 +-
.../VectorToXeGPU/VectorToXeGPU.cpp | 6 +-
.../VectorToXeGPU/transfer-read-to-xegpu.mlir | 93 +++++++++++--------
.../transfer-write-to-xegpu.mlir | 64 ++++++++-----
4 files changed, 105 insertions(+), 62 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 33eab14e9dfd8..d73c33e8fddae 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -25,10 +25,10 @@ def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
-def XeGPU_ValueType: VectorOfRankAndType<[1,2,3,4,5,6], [XeGPU_ScalarType]>;
+def XeGPU_ValueType: VectorOfRankAndType<[1,2,3,4,5,6,7,8], [XeGPU_ScalarType]>;
def XeGPU_ValueOrScalarType : AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>;
def XeGPU_VectorOrScalarType
- : AnyTypeOf<[VectorOfRankAndType<[1,2,3,4,5,6], [XeGPU_ScalarType, Index]>, XeGPU_ScalarType]>;
+ : AnyTypeOf<[VectorOfRankAndType<[1,2,3,4,5,6,7,8], [XeGPU_ScalarType, Index]>, XeGPU_ScalarType]>;
def XeGPU_GatherScatterBaseAddrType
: AnyTypeOf<[MemRefRankOf<[XeGPU_ScalarType], [1]>, XeGPU_PointerType]>;
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index bbb6340f14c51..c686eb50a5448 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -550,7 +550,8 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
// TODO:This check needs to be replaced with proper uArch capability check
auto chip = xegpu::getChipStr(readOp);
- if (chip != "pvc" && chip != "bmg") {
+ if ((chip != "pvc" && chip != "bmg") ||
+ readOp.getVectorType().getRank() > 2) {
// lower to scattered load Op if the target HW doesn't have 2d block load
// support
// TODO: add support for OutOfBound access
@@ -634,7 +635,8 @@ struct TransferWriteLowering
// TODO:This check needs to be replaced with proper uArch capability check
auto chip = xegpu::getChipStr(writeOp);
- if (chip != "pvc" && chip != "bmg") {
+ if ((chip != "pvc" && chip != "bmg") ||
+ writeOp.getVectorType().getRank() > 2) {
// lower to scattered store Op if the target HW doesn't have 2d block
// store support
// TODO: add support for OutOfBound access
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 1a19c8a13f120..c4c7bc1664823 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=LOAD-ND
-// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=LOAD-GATHER
+// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefixes=LOAD-ND,CHECK
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefixes=LOAD-GATHER,CHECK
gpu.module @xevm_module {
gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
@@ -235,24 +235,21 @@ gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
gpu.return %0 : vector<2x4x8x16xf32>
}
-// LOAD-ND-LABEL: @load_dynamic_source3(
-// LOAD-ND: vector.transfer_read
-
-// LOAD-GATHER-LABEL: @load_dynamic_source3(
-// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<?x?x?x?x?xf32>
-// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<2x4x8x16xi1>
-// LOAD-GATHER: memref.extract_strided_metadata %[[SRC]] : memref<?x?x?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index, index, index, index, index
-// LOAD-GATHER-COUNT4: vector.step
-// LOAD-GATHER-COUNT3: vector.broadcast
-// LOAD-GATHER-COUNT4: vector.shape_cast
-// LOAD-GATHER-COUNT4: vector.broadcast {{.*}} : vector<2x4x8x16xindex>
-// LOAD-GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex>
-// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
-// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?x?x?xf32> -> index
-// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
-// LOAD-GATHER: return %[[VEC]]
+// CHECK-LABEL: @load_dynamic_source3(
+// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?x?x?xf32>
+// CHECK: %[[CST:.+]] = arith.constant dense<true> : vector<2x4x8x16xi1>
+// CHECK: memref.extract_strided_metadata %[[SRC]] : memref<?x?x?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index, index, index, index, index
+// CHECK-COUNT4: vector.step
+// CHECK-COUNT3: vector.broadcast
+// CHECK-COUNT4: vector.shape_cast
+// CHECK-COUNT4: vector.broadcast {{.*}} : vector<2x4x8x16xindex>
+// CHECK-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex>
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
+// CHECK: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?x?x?xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
+// CHECK: return %[[VEC]]
}
// -----
@@ -265,24 +262,46 @@ gpu.func @load_high_dim_vector(%source: memref<16x32x64xf32>,
gpu.return %0 : vector<8x16x32xf32>
}
-// LOAD-ND-LABEL: @load_high_dim_vector(
-// LOAD-ND: vector.transfer_read
+// CHECK-LABEL: @load_high_dim_vector(
+// CHECK: %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
+// CHECK: %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
+// CHECK: %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
+// CHECK: %[[C2048:.+]] = arith.constant 2048 : index
+// CHECK: %[[C64:.+]] = arith.constant 64 : index
+// CHECK-COUNT3: vector.step
+// CHECK-COUNT3: vector.shape_cast
+// CHECK-COUNT3: vector.broadcast {{.*}} : vector<8x16x32xindex>
+// CHECK-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
+// CHECK: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
+// CHECK: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<16x32x64xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
-// LOAD-GATHER-LABEL: @load_high_dim_vector(
-// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
-// LOAD-GATHER: %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
-// LOAD-GATHER: %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
-// LOAD-GATHER: %[[C2048:.+]] = arith.constant 2048 : index
-// LOAD-GATHER: %[[C64:.+]] = arith.constant 64 : index
-// LOAD-GATHER-COUNT3: vector.step
-// LOAD-GATHER-COUNT3: vector.shape_cast
-// LOAD-GATHER-COUNT3: vector.broadcast {{.*}} : vector<8x16x32xindex>
-// LOAD-GATHER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
-// LOAD-GATHER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
-// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<16x32x64xf32> -> index
-// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_8D_vector(%source: memref<2x2x2x2x2x2x2x2xf32>,
+ %offset: index) -> vector<2x2x2x2x2x2x2x2xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %source[%offset, %offset, %offset, %offset, %offset, %offset, %offset, %offset], %c0
+ {in_bounds = [true, true, true, true, true, true, true, true]} : memref<2x2x2x2x2x2x2x2xf32>, vector<2x2x2x2x2x2x2x2xf32>
+ gpu.return %0 : vector<2x2x2x2x2x2x2x2xf32>
+}
+
+// CHECK-LABEL: @load_8D_vector(
+// CHECK-SAME: %[[SRC:.+]]: memref<2x2x2x2x2x2x2x2xf32>,
+// CHECK: %[[CST:.+]] = arith.constant dense<true> : vector<2x2x2x2x2x2x2x2xi1>
+// CHECK-COUNT8: vector.step
+// CHECK-COUNT7: vector.shape_cast
+// CHECK-COUNT8: vector.broadcast {{.*}} : vector<2x2x2x2x2x2x2x2xindex>
+// CHECK-COUNT7: arith.addi {{.*}} : vector<2x2x2x2x2x2x2x2xindex>
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x2x2x2x2x2x2x2xindex>
+// CHECK: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x2x2x2x2x2x2x2xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<2x2x2x2x2x2x2x2xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<2x2x2x2x2x2x2x2xindex>, vector<2x2x2x2x2x2x2x2xi1> -> vector<2x2x2x2x2x2x2x2xf32>
}
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 66da64225678e..0cda8d57b34e5 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND
-// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-SCATTER
+// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefixes=STORE-ND,CHECK
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefixes=STORE-SCATTER,CHECK
gpu.module @xevm_module {
@@ -187,26 +187,48 @@ gpu.func @store_high_dim_vector(%vec: vector<8x16x32xf32>,
gpu.return
}
-// STORE-ND-LABEL: @store_high_dim_vector(
-// STORE-ND: vector.transfer_write
+// CHECK-LABEL: @store_high_dim_vector(
+// CHECK-SAME: %[[VEC:.+]]: vector<8x16x32xf32>,
+// CHECK-SAME: %[[SRC:.+]]: memref<16x32x64xf32>
+// CHECK: %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
+// CHECK: %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
+// CHECK: %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
+// CHECK: %[[C2048:.+]] = arith.constant 2048 : index
+// CHECK: %[[C64:.+]] = arith.constant 64 : index
+// CHECK-COUNT3: vector.step
+// CHECK-COUNT3: vector.shape_cast
+// CHECK-COUNT3: vector.broadcast {{.*}} : vector<8x16x32xindex>
+// CHECK-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
+// CHECK: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
+// CHECK: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<16x32x64xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: xegpu.store %[[VEC]], %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, i64, vector<8x16x32xindex>, vector<8x16x32xi1>
+}
-// STORE-SCATTER-LABEL: @store_high_dim_vector(
-// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16x32xf32>,
-// STORE-SCATTER-SAME: %[[SRC:.+]]: memref<16x32x64xf32>
-// STORE-SCATTER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
-// STORE-SCATTER: %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
-// STORE-SCATTER: %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
-// STORE-SCATTER: %[[C2048:.+]] = arith.constant 2048 : index
-// STORE-SCATTER: %[[C64:.+]] = arith.constant 64 : index
-// STORE-SCATTER-COUNT3: vector.step
-// STORE-SCATTER-COUNT3: vector.shape_cast
-// STORE-SCATTER-COUNT3: vector.broadcast {{.*}} : vector<8x16x32xindex>
-// STORE-SCATTER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
-// STORE-SCATTER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
-// STORE-SCATTER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
-// STORE-SCATTER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<16x32x64xf32> -> index
-// STORE-SCATTER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
-// STORE-SCATTER: xegpu.store %[[VEC]], %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, i64, vector<8x16x32xindex>, vector<8x16x32xi1>
+// -----
+gpu.module @xevm_module {
+gpu.func @store_8D_vector(%vec: vector<2x2x2x2x2x2x2x2xf32>,
+ %source: memref<2x2x2x2x2x2x2x2xf32>, %offset: index) {
+ vector.transfer_write %vec, %source[%offset, %offset, %offset, %offset, %offset, %offset, %offset, %offset]
+ {in_bounds = [true, true, true, true, true, true, true, true]}
+ : vector<2x2x2x2x2x2x2x2xf32>, memref<2x2x2x2x2x2x2x2xf32>
+ gpu.return
+}
+
+// CHECK-LABEL: @store_8D_vector(
+// CHECK-SAME: %[[VEC:.+]]: vector<2x2x2x2x2x2x2x2xf32>,
+// CHECK-SAME: %[[SRC:.+]]: memref<2x2x2x2x2x2x2x2xf32>
+// CHECK: %[[CST:.+]] = arith.constant dense<true> : vector<2x2x2x2x2x2x2x2xi1>
+// CHECK-COUNT8: vector.step
+// CHECK-COUNT7: vector.shape_cast
+// CHECK-COUNT8: vector.broadcast {{.*}} : vector<2x2x2x2x2x2x2x2xindex>
+// CHECK-COUNT7: arith.addi {{.*}} : vector<2x2x2x2x2x2x2x2xindex>
+// CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x2x2x2x2x2x2x2xindex>
+// CHECK: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x2x2x2x2x2x2x2xindex>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<2x2x2x2x2x2x2x2xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: xegpu.store %[[VEC]], %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : vector<2x2x2x2x2x2x2x2xf32>, i64, vector<2x2x2x2x2x2x2x2xindex>, vector<2x2x2x2x2x2x2x2xi1>
}
// -----
More information about the Mlir-commits
mailing list