[Mlir-commits] [mlir] 66134e7 - [MLIR][XeVM] Improve matrix ops lowering (#170268)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 8 09:29:19 PST 2025
Author: Artem Kroviakov
Date: 2025-12-08T18:29:15+01:00
New Revision: 66134e7136154cbb839484afd8f356cc4ed4a021
URL: https://github.com/llvm/llvm-project/commit/66134e7136154cbb839484afd8f356cc4ed4a021
DIFF: https://github.com/llvm/llvm-project/commit/66134e7136154cbb839484afd8f356cc4ed4a021.diff
LOG: [MLIR][XeVM] Improve matrix ops lowering (#170268)
Added:
Modified:
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 2b162ec3f3bf4..54254be007788 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -676,16 +676,25 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
Value baseAddr32 = adaptor.getMemDesc();
Value mdescVal = op.getMemDesc();
// Load result or Store value Type can be vector or scalar.
- Value data;
- if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
- data = op.getResult();
- else
- data = adaptor.getData();
- VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+ Type dataTy;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Type resType = op.getResult().getType();
+ // Some transforms may leave unit dimension in the 2D vector, adaptors do
+ // not catch it for results.
+ if (auto vecType = dyn_cast<VectorType>(resType)) {
+ auto nonUnitDims = llvm::count_if(vecType.getShape(),
+ [](int64_t d) { return d != 1; });
+ assert(nonUnitDims <= 1 &&
+ "Expected either 1D vector or nD with unit dimensions");
+ resType = VectorType::get({vecType.getNumElements()},
+ vecType.getElementType());
+ }
+ dataTy = resType;
+ } else
+ dataTy = adaptor.getData().getType();
+ VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
if (!valOrResVecTy)
- valOrResVecTy = VectorType::get(1, data.getType());
- if (valOrResVecTy.getShape().size() != 1)
- return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
+ valOrResVecTy = VectorType::get(1, dataTy);
int64_t elemBitWidth =
valOrResVecTy.getElementType().getIntOrFloatBitWidth();
@@ -1176,6 +1185,7 @@ struct ConvertXeGPUToXeVMPass
};
typeConverter.addSourceMaterialization(
singleElementVectorMaterializationCast);
+ typeConverter.addSourceMaterialization(vectorMaterializationCast);
typeConverter.addTargetMaterialization(memrefMaterializationCast);
typeConverter.addTargetMaterialization(ui32MaterializationCast);
typeConverter.addTargetMaterialization(ui64MaterializationCast);
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index ac95a1a5707ea..3a3769f3a4f70 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -30,7 +30,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
%c0 = arith.constant 0 : index
%view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
-
+
%subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
%0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
@@ -43,7 +43,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
%tid_x = gpu.thread_id x
-
+
%1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
//CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
@@ -81,15 +81,15 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
//CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
-
+
%tid_x = gpu.thread_id x
%c13 = arith.constant 13 : index
%1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16
//CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
-
- xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+
+ xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
gpu.return %1: f16
}
@@ -102,12 +102,12 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-
+
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
//CHECK: %[[c19:.*]] = arith.constant 19 : index
%tid_x = gpu.thread_id x
%c19 = arith.constant 19: index
-
+
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
@@ -127,10 +127,10 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
//CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
%1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
-
+
//CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
-
+
//CHECK: gpu.return %[[loaded]] : f16
gpu.return %1: f16
}
@@ -161,7 +161,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
//CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
-
+
%tid_x = gpu.thread_id x
%c16 = arith.constant 16 : index
%1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
@@ -172,7 +172,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
gpu.return %1: vector<8xf16>
}
-
+
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
//CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
@@ -214,11 +214,22 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
%1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
//CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16>
- //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
+ //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
gpu.return %1: vector<8xf16>
}
+ gpu.func @matrix_vector_materialization(%matrixdesc : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
+ // CHECK: %[[XEVM_VECTOR:.*]] = llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16>
+ // CHECK: %[[SOURCE_MATERIALIZE:.*]] = vector.shape_cast %[[XEVM_VECTOR]] : vector<16xf16> to vector<1x16xf16>
+ // CHECK: %[[XEGPU_VECTOR:.*]] = arith.addf %[[SOURCE_MATERIALIZE]], %[[SOURCE_MATERIALIZE]] : vector<1x16xf16>
+ // CHECK: %[[TARGET_MATERIALIZE:.*]] = vector.shape_cast %[[XEGPU_VECTOR]] : vector<1x16xf16> to vector<16xf16>
+ // CHECK: llvm.store %[[TARGET_MATERIALIZE]], %{{.*}} : vector<16xf16>, !llvm.ptr<3>
+ %loaded = xegpu.load_matrix %matrixdesc[16,0] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<1x16xf16>
+ %loaded_2 = arith.addf %loaded, %loaded : vector<1x16xf16>
+ xegpu.store_matrix %loaded_2, %matrixdesc[16,0] : vector<1x16xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+ gpu.return
+ }
}
More information about the Mlir-commits mailing list