[Mlir-commits] [mlir] [MLIR][XeGPU] Allow create mem desc from 2d memref (PR #167767)

Jianhui Li llvmlistbot at llvm.org
Tue Nov 18 11:07:40 PST 2025


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/167767
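
This change relaxes xegpu.create_mem_desc so that, in addition to the
existing 1D i8 form, a statically shaped 2D memref residing in shared
local memory (SLM) may serve as the source. A minimal sketch of the two
accepted forms, based on the tests added below (%raw and %m2d are
placeholder SSA names):

  %a = xegpu.create_mem_desc %raw : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
  %b = xegpu.create_mem_desc %m2d : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>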

>From d9a53760818ecb943e3f1d006a52ceb3cbc1703c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 11 Nov 2025 05:48:52 +0000
Subject: [PATCH 1/4] initial implementation
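
Accept a statically shaped 2D SLM memref as the source of
xegpu.create_mem_desc, and lower the op in XeGPUToXeVM by materializing
the SLM base address as an i32 instead of building a memref.view; the
load/store patterns then add the linearized offset to that base. A rough
sketch of the new lowering (SSA names are placeholders):

  %md = xegpu.create_mem_desc %src : memref<32x32xf32, 3> -> !xegpu.mem_desc<32x32xf32>

becomes

  %intptr = memref.extract_aligned_pointer_as_index %src : memref<32x32xf32, 3> -> index
  %base32 = arith.index_castui %intptr : index to i32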

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  2 +-
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 45 +++++++++--------
 .../XeGPUToXeVM/loadstore_matrix.mlir         | 48 +++++++++++++++----
 mlir/test/Dialect/XeGPU/ops.mlir              | 21 ++++++++
 4 files changed, 86 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 689ebd0d1179a..59dc0eade0744 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1308,7 +1308,7 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
     Results:
      - `mem_desc` : the memory descriptor.
   }];
-  let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source);
+  let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[I8]>, ConfinedType<MemRefRankOf<[XeGPU_ScalarType], [2]>, [HasStaticShapePred, isSharedPred]>]>:$source);
   let results = (outs XeGPU_MemDesc:$mem_desc);
   let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))";
 }
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 705298f497d20..02b60ce5f6a1d 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -590,16 +590,22 @@ class CreateMemDescOpPattern final
   matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto resTy = op.getMemDesc();
+//    auto resTy = op.getMemDesc();
 
     // Create the result MemRefType with the same shape, element type, and
     // memory space
-    auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
+//    auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
 
-    Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
-    auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
-                                         op.getSource(), zero, ValueRange());
-    rewriter.replaceOp(op, viewOp);
+//    Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+//    auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
+//                                         op.getSource(), zero, ValueRange());
+
+    Value baseAddr = memref::ExtractAlignedPointerAsIndexOp::create(
+                      rewriter, op.getLoc(), op.getSource());
+    auto  baseAddr32 = arith::IndexCastUIOp::create(
+                      rewriter, op.getLoc(), rewriter.getI32Type(), baseAddr);
+
+    rewriter.replaceOp(op, baseAddr32);
     return success();
   }
 };
@@ -619,7 +625,7 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
 
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
-    Value basePtrStruct = adaptor.getMemDesc();
+    Value baseAddr32 = adaptor.getMemDesc();
     Value mdescVal = op.getMemDesc();
     // Load result or Store value Type can be vector or scalar.
     Value data;
@@ -647,21 +653,21 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
 
     auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
 
-    Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
-        rewriter, loc, basePtrStruct);
+    // Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
+    //    rewriter, loc, basePtrStruct);
 
     // Convert base pointer (ptr) to i32
-    Value basePtrI32 = arith::IndexCastUIOp::create(
-        rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
+    //Value basePtrI32 = arith::IndexCastUIOp::create(
+    //    rewriter, loc, rewriter.getI32Type(), baseAddr);
 
     Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
     linearOffset = arith::IndexCastUIOp::create(
         rewriter, loc, rewriter.getI32Type(), linearOffset);
-    basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
+    Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32, linearOffset,
                                      elemByteSize);
 
     // convert base pointer (i32) to LLVM pointer type
-    basePtrLLVM =
+    Value basePtrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
 
     if (op.getSubgroupBlockIoAttr()) {
@@ -1005,12 +1011,13 @@ struct ConvertXeGPUToXeVMPass
       auto i32Type = IntegerType::get(&getContext(), 32);
       return VectorType::get(8, i32Type);
     });
-    // Convert MemDescType into flattened MemRefType for SLM
-    typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
-      Type elemTy = type.getElementType();
-      int numElems = type.getNumElements();
-      return MemRefType::get(numElems, elemTy, AffineMap(), 3);
-    });
+    // Convert MemDescType into i32 for SLM
+     typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
+    //  Type elemTy = type.getElementType();
+    //  int numElems = type.getNumElements();
+    //  return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+        return IntegerType::get(&getContext(), 32);
+     });
 
     typeConverter.addConversion([&](MemRefType type) -> Type {
       // Convert MemRefType to i64 type.
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index d4cb493271d0d..2b5c308f9ab86 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -4,8 +4,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
 
  // e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
   // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
-  //CHECK-LABEL: load_store_matrix_1
-  gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
+  //CHECK-LABEL: load_store_matrix_plain
+  gpu.func @load_store_matrix_plain(%arg0: memref<4096xi8, 3>) -> f32 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
 
     //CHECK: %[[TID:.*]] = gpu.thread_id x
@@ -26,10 +26,38 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     gpu.return %1: f32
   }
 
+  //CHECK-LABEL: load_store_matrix_plain_2d_input
+  gpu.func @load_store_matrix_plain(%arg0: memref<8192xi8, 3>) -> f32 {
+
+    %view = memref.view %arg0[0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
+    
+    %subview = memref.subview %view[64, 0] [64, 128] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
+
+    %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
+
+    //CHECK: %[[TID:.*]] = gpu.thread_id x
+    //CHECK: %[[C1:.*]] = arith.constant 1 : index
+    //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+    //CHECK: %[[C4:.*]] = arith.constant 4 : i32
+    //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
+    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+
+    %tid_x = gpu.thread_id x
+    %c0 = arith.constant 0 : index
+    %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
+
+    //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
+
+    xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+
+    gpu.return %1: f32
+  }
+
+
 // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
   // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
-  //CHECK-LABEL: load_store_matrix_2
-  gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
+  //CHECK-LABEL: load_store_matrix_blocked_strided
+  gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
     //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[tid_x:.*]] = gpu.thread_id x
@@ -68,8 +96,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
 
   // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
   // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
-  //CHECK-LABEL: load_store_matrix_3
-  gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
+  //CHECK-LABEL: load_store_matrix_blocked_nostride
+  gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
     //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
@@ -110,8 +138,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
 
    // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
   // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
-  //CHECK-LABEL: load_store_matrix_4
-  gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+  //CHECK-LABEL: load_store_matrix_blocked_strided_return_vector
+  gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
 
     //CHECK: %[[c0:.*]] = arith.constant 0 : index
@@ -150,8 +178,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
  
   // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
   // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
-  //CHECK-LABEL: load_store_matrix_5
-  gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+  //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
+  gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
     //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
  
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 9b3829664108d..1e9738f44bb66 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -834,6 +834,27 @@ gpu.func @create_mem_desc_with_stride() {
   gpu.return
 }
 
+
+// CHECK-LABEL: gpu.func @create_mem_desc_from_2d_memref({{.*}}) {
+gpu.func @create_mem_desc_from_2d_memref() {
+  //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<16x64xf16, 3>
+  //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
+  %m = memref.alloca() {alignment = 1024} : memref<16x64xf16, 3>
+  %mem_desc = xegpu.create_mem_desc %m : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @create_mem_desc_with_stride_from_2d_memref({{.*}}) {
+gpu.func @create_mem_desc_with_stride_from_2d_memref() {
+  //CHECK: %[[ALLOC:.+]] = memref.alloca() {alignment = 1024 : i64} : memref<32x64xf16, 3>
+  //CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16, 0] [16, 64] [1, 1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
+  //CHECK: %{{.+}} = xegpu.create_mem_desc %[[SUBVIEW]] : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+  %m = memref.alloca() {alignment = 1024} : memref<32x64xf16, 3>
+  %m_sub = memref.subview %m[16, 0][16, 64][1,1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
+  %mem_desc = xegpu.create_mem_desc %m_sub : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
 gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
   // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>

>From 18f8cb976d61b26e37e9b00ad6b9e770c288e73f Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 12 Nov 2025 21:41:09 +0000
Subject: [PATCH 2/4] adding tests
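
Clean up the code commented out in the first commit and fix the 2D input
test: the memref.view offset must be an index SSA value, and the subview
now takes a legal 32x32 slice of the 64x32 view. The corrected sequence
(copied from the test below) is:

  %c0 = arith.constant 0 : index
  %view = memref.view %arg0[%c0][] : memref<8192xi8, 3> to memref<64x32xf32, 3>
  %subview = memref.subview %view[32, 0] [32, 32] [1, 1]
      : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>

The static offset is 32 rows * 32 elements per row = 1024 elements.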

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 36 ++++------------
 .../XeGPUToXeVM/loadstore_matrix.mlir         | 42 +++++++++----------
 2 files changed, 27 insertions(+), 51 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 02b60ce5f6a1d..c01078ca6938e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -590,20 +590,10 @@ class CreateMemDescOpPattern final
   matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-//    auto resTy = op.getMemDesc();
-
-    // Create the result MemRefType with the same shape, element type, and
-    // memory space
-//    auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
-
-//    Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
-//    auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
-//                                         op.getSource(), zero, ValueRange());
-
     Value baseAddr = memref::ExtractAlignedPointerAsIndexOp::create(
-                      rewriter, op.getLoc(), op.getSource());
-    auto  baseAddr32 = arith::IndexCastUIOp::create(
-                      rewriter, op.getLoc(), rewriter.getI32Type(), baseAddr);
+        rewriter, op.getLoc(), op.getSource());
+    auto baseAddr32 = arith::IndexCastUIOp::create(
+        rewriter, op.getLoc(), rewriter.getI32Type(), baseAddr);
 
     rewriter.replaceOp(op, baseAddr32);
     return success();
@@ -653,18 +643,11 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
 
     auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
 
-    // Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
-    //    rewriter, loc, basePtrStruct);
-
-    // Convert base pointer (ptr) to i32
-    //Value basePtrI32 = arith::IndexCastUIOp::create(
-    //    rewriter, loc, rewriter.getI32Type(), baseAddr);
-
     Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
     linearOffset = arith::IndexCastUIOp::create(
         rewriter, loc, rewriter.getI32Type(), linearOffset);
-    Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32, linearOffset,
-                                     elemByteSize);
+    Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
+                                           linearOffset, elemByteSize);
 
     // convert base pointer (i32) to LLVM pointer type
     Value basePtrLLVM =
@@ -1012,12 +995,9 @@ struct ConvertXeGPUToXeVMPass
       return VectorType::get(8, i32Type);
     });
     // Convert MemDescType into i32 for SLM
-     typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
-    //  Type elemTy = type.getElementType();
-    //  int numElems = type.getNumElements();
-    //  return MemRefType::get(numElems, elemTy, AffineMap(), 3);
-        return IntegerType::get(&getContext(), 32);
-     });
+    typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
+      return IntegerType::get(&getContext(), 32);
+    });
 
     typeConverter.addConversion([&](MemRefType type) -> Type {
       // Convert MemRefType to i64 type.
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index 2b5c308f9ab86..5b3776ba0039c 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -27,11 +27,11 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   }
 
   //CHECK-LABEL: load_store_matrix_plain_2d_input
-  gpu.func @load_store_matrix_plain(%arg0: memref<8192xi8, 3>) -> f32 {
-
-    %view = memref.view %arg0[0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
+  gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
+    %c0 = arith.constant 0 : index
+    %view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
     
-    %subview = memref.subview %view[64, 0] [64, 128] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
+    %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
 
     %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
 
@@ -43,7 +43,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
 
     %tid_x = gpu.thread_id x
-    %c0 = arith.constant 0 : index
+
     %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
 
     //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
@@ -59,7 +59,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   //CHECK-LABEL: load_store_matrix_blocked_strided
   gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
+
     //CHECK: %[[tid_x:.*]] = gpu.thread_id x
     //CHECK: %[[c13:.*]] = arith.constant 13 : index
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
@@ -67,7 +67,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
     //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
     //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[c256:.*]] = arith.constant 256 : index
     //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
     //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -98,8 +98,9 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
   //CHECK-LABEL: load_store_matrix_blocked_nostride
   gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
+
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
+    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
     
     //CHECK: %[[tid_x:.*]] = gpu.thread_id x
@@ -107,13 +108,12 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     %tid_x = gpu.thread_id x
     %c19 = arith.constant 19: index
     
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
-    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
     //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
     //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
     //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
     //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
     //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
     //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -125,7 +125,6 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[c1:.*]] = arith.constant 1 : index
     //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
     //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
     %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
     
@@ -142,15 +141,13 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
 
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[tid_x:.*]] = gpu.thread_id x
-
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
     //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
     //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
     //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
     //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[c256:.*]] = arith.constant 256 : index
     //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
     //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -180,23 +177,22 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
   //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
   gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
- 
-    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
- 
+
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
+    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
+    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+
+
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
     //CHECK: %[[c48:.*]] = arith.constant 48 : index
-  
     %c16 = arith.constant 16 : index
     %c48 = arith.constant 48 : index
 
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
-    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
     //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
     //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
     //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
     //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
     //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
     //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index

>From c0ba80e50f65af0fe7daa9e91afdfc271326665e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 12 Nov 2025 22:05:35 +0000
Subject: [PATCH 3/4] fixing documentation
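
Reword the description of the `source` argument to cover the new 2D
form. With this change the strided 2D case from the tests reads, for
example (%sub is a placeholder for a subview of an SLM allocation):

  %md = xegpu.create_mem_desc %sub : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
      -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>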

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 59dc0eade0744..cd5824b607aba 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1304,7 +1304,7 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
     as the underlying shared local memory.
 
     Arguments:
-     - `source` : a 1D statically shaped memref with element type i8, representing the raw SLM buffer.
+     - `source` : a 1D or 2D statically shaped memref, representing the raw SLM buffer. When the source is a 1D memref, its element type must be i8.
     Results:
      - `mem_desc` : the memory descriptor.
   }];

>From 281393e862ecea749044b111d61589f2f02b3d4c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 18 Nov 2025 19:07:17 +0000
Subject: [PATCH 4/4] address feedback
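
Move the shared-memory memref constraints into XeGPUTypes.td, add a 2D
variant, and let both the 1D and 2D forms take any XeGPU scalar element
type. As a sketch of what isSharedPred rejects, a memref in a non-SLM
address space fails verification (taken from the updated invalid.mlir):

  %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 1>
  // address space 1 is not shared memory, so this does not verify:
  %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 1> -> !xegpu.mem_desc<16x64xf16>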

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td   |  8 +-------
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 11 +++++++++++
 mlir/test/Dialect/XeGPU/invalid.mlir             |  2 +-
 3 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 7ac5f495a9789..ed166260fb689 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1282,12 +1282,6 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
     let hasCanonicalizer = 1;
 }
 
-def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
-class StaticShared1DMemRefOf<list<Type> allowedTypes> :
-  ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred],
-     "statically shaped " # MemRefOf<allowedTypes>.summary # " for shared memory",
-     "mlir::MemRefType">;
-
 class SizeInBits<string name> :
   StrFunc<"llvm::cast<mlir::ShapedType>($" # name # ".getType()).getNumElements()"
           "*llvm::cast<mlir::ShapedType>($" # name # ".getType()).getElementTypeBitWidth()">;
@@ -1308,7 +1302,7 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
     Results:
      - `mem_desc` : the memory descriptor.
   }];
-  let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[I8]>, ConfinedType<MemRefRankOf<[XeGPU_ScalarType], [2]>, [HasStaticShapePred, isSharedPred]>]>:$source);
+  let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[XeGPU_ScalarType]>, StaticShared2DMemRefOf<[XeGPU_ScalarType]>]>:$source);
   let results = (outs XeGPU_MemDesc:$mem_desc);
   let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))";
 }
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index b1196fbe9c66a..7f7f7d065c50e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -35,6 +35,17 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
   let mnemonic = typeMnemonic;
 }
 
+def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
+class StaticShared1DMemRefOf<list<Type> allowedTypes> :
+  ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred],
+     "reside in share memory and statically 1d shaped " # MemRefOf<allowedTypes>.summary # " ",
+     "mlir::MemRefType">;
+
+class StaticShared2DMemRefOf<list<Type> allowedTypes>:
+  ConfinedType<MemRefRankOf<allowedTypes, [2]>, [HasStaticShapePred, isSharedPred],
+     "reside in share memory and statically 2d shaped " # MemRefOf<allowedTypes>.summary # " ",
+     "mlir::MemRefType">;
+
 def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
         [ShapedTypeInterface], "::mlir::TensorType"> {
   let summary = "TensorDesc describing regions of interested data.";
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 92f353717ac59..67faa60f2835e 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -836,7 +836,7 @@ func.func @slice_attr_repeat_dim() {
 // -----
 func.func @create_mem_desc_non_slm() {
   %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 1>
-  // expected-error@+1 {{operand #0 must be statically shaped memref of 8-bit signless integer values for shared memory}}
+  // expected-error@+1 {{operand #0 must be statically shaped 1D memref}}
   %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 1> -> !xegpu.mem_desc<16x64xf16>
   return
 }


