[Mlir-commits] [mlir] 66134e7 - [MLIR][XeVM] Improve matrix ops lowering (#170268)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 8 09:29:19 PST 2025


Author: Artem Kroviakov
Date: 2025-12-08T18:29:15+01:00
New Revision: 66134e7136154cbb839484afd8f356cc4ed4a021

URL: https://github.com/llvm/llvm-project/commit/66134e7136154cbb839484afd8f356cc4ed4a021
DIFF: https://github.com/llvm/llvm-project/commit/66134e7136154cbb839484afd8f356cc4ed4a021.diff

LOG: [MLIR][XeVM] Improve matrix ops lowering (#170268)

Added: 
    

Modified: 
    mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
    mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 2b162ec3f3bf4..54254be007788 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -676,16 +676,25 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
     Value baseAddr32 = adaptor.getMemDesc();
     Value mdescVal = op.getMemDesc();
     // Load result or Store value Type can be vector or scalar.
-    Value data;
-    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
-      data = op.getResult();
-    else
-      data = adaptor.getData();
-    VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+    Type dataTy;
+    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+      Type resType = op.getResult().getType();
+      // Some transforms may leave a unit dimension in the 2D vector; adaptors
+      // do not catch it for results.
+      if (auto vecType = dyn_cast<VectorType>(resType)) {
+        auto nonUnitDims = llvm::count_if(vecType.getShape(),
+                                          [](int64_t d) { return d != 1; });
+        assert(nonUnitDims <= 1 &&
+               "Expected either 1D vector or nD with unit dimensions");
+        resType = VectorType::get({vecType.getNumElements()},
+                                  vecType.getElementType());
+      }
+      dataTy = resType;
+    } else
+      dataTy = adaptor.getData().getType();
+    VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
     if (!valOrResVecTy)
-      valOrResVecTy = VectorType::get(1, data.getType());
-    if (valOrResVecTy.getShape().size() != 1)
-      return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
+      valOrResVecTy = VectorType::get(1, dataTy);
 
     int64_t elemBitWidth =
         valOrResVecTy.getElementType().getIntOrFloatBitWidth();
@@ -1176,6 +1185,7 @@ struct ConvertXeGPUToXeVMPass
     };
     typeConverter.addSourceMaterialization(
         singleElementVectorMaterializationCast);
+    typeConverter.addSourceMaterialization(vectorMaterializationCast);
     typeConverter.addTargetMaterialization(memrefMaterializationCast);
     typeConverter.addTargetMaterialization(ui32MaterializationCast);
     typeConverter.addTargetMaterialization(ui64MaterializationCast);

diff  --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index ac95a1a5707ea..3a3769f3a4f70 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -30,7 +30,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
     %c0 = arith.constant 0 : index
     %view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
-    
+
     %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
 
     %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
@@ -43,7 +43,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
 
     %tid_x = gpu.thread_id x
- 
+
     %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
 
     //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
@@ -81,15 +81,15 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
 
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
- 
+
 
     %tid_x = gpu.thread_id x
     %c13 = arith.constant 13 : index
     %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16
 
     //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
-   
-    xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index 
+
+    xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
     gpu.return %1: f16
   }
 
@@ -102,12 +102,12 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
     //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-    
+
     //CHECK: %[[tid_x:.*]] = gpu.thread_id x
     //CHECK: %[[c19:.*]] = arith.constant 19 : index
     %tid_x = gpu.thread_id x
     %c19 = arith.constant 19: index
-    
+
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
     //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
     //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
@@ -127,10 +127,10 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
     %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
-    
+
     //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
     xegpu.store_matrix %1, %0[%c19, %tid_x]:  f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
-    
+
     //CHECK: gpu.return %[[loaded]] : f16
     gpu.return %1: f16
   }
@@ -161,7 +161,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
 
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
-     
+
     %tid_x = gpu.thread_id x
     %c16 = arith.constant 16 : index
     %1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
@@ -172,7 +172,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     gpu.return %1: vector<8xf16>
   }
 
- 
+
   // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
   // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
   //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
@@ -214,11 +214,22 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
 
     //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16>
-    //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>) 
+    //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
 
     xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
 
     gpu.return %1: vector<8xf16>
   }
 
+  gpu.func @matrix_vector_materialization(%matrixdesc : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
+    // CHECK: %[[XEVM_VECTOR:.*]] = llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16>
+    // CHECK: %[[SOURCE_MATERIALIZE:.*]] = vector.shape_cast %[[XEVM_VECTOR]] : vector<16xf16> to vector<1x16xf16>
+    // CHECK: %[[XEGPU_VECTOR:.*]] = arith.addf %[[SOURCE_MATERIALIZE]], %[[SOURCE_MATERIALIZE]] : vector<1x16xf16>
+    // CHECK: %[[TARGET_MATERIALIZE:.*]] = vector.shape_cast %[[XEGPU_VECTOR]] : vector<1x16xf16> to vector<16xf16>
+    // CHECK: llvm.store %[[TARGET_MATERIALIZE]], %{{.*}} : vector<16xf16>, !llvm.ptr<3>
+    %loaded = xegpu.load_matrix %matrixdesc[16,0] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<1x16xf16>
+    %loaded_2 = arith.addf %loaded, %loaded : vector<1x16xf16>
+    xegpu.store_matrix %loaded_2, %matrixdesc[16,0] : vector<1x16xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+    gpu.return
+  }
 }


        


More information about the Mlir-commits mailing list