[Mlir-commits] [mlir] [MLIR][XeGPU] Support subview memref: handling the base address during xegpu to xevm type conversion (PR #170541)

Jianhui Li llvmlistbot at llvm.org
Wed Dec 3 13:49:32 PST 2025


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/170541

>From 43d9ddb971ed0d80a44ef95d6d7751c5996337bb Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 2 Dec 2025 00:13:27 +0000
Subject: [PATCH 1/2] support memref subview in xegpu to xevm type conversion

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 43 ++++++++++++++++---
 .../XeGPUToXeVM/loadstore_matrix.mlir         | 22 ++++++----
 2 files changed, 52 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 7f1ec17ce0ae8..bafd1dc348e5b 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -999,7 +999,7 @@ struct ConvertXeGPUToXeVMPass
     // LLVM type converter puts unrealized casts for the following cases:
     // add materialization casts to handle them.
 
-    // Materialization to convert memref to i64
+    // Materialization to convert memref to i64 or i32 depending on global/SLM
     auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                         ValueRange inputs,
                                         Location loc) -> Value {
@@ -1007,11 +1007,44 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
+        unsigned rank = memrefTy.getRank();
+        Type indexType = builder.getIndexType();
 
-        Value addr =
-            memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
-        return arith::IndexCastUIOp::create(builder, loc, type, addr)
-            .getResult();
+        SmallVector<Type> resultTypes;
+        // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
+        // size0, size1, ..., sizeN-1]
+        resultTypes.push_back(MemRefType::get(
+            {}, memrefTy.getElementType(), MemRefLayoutAttrInterface(),
+            memrefTy.getMemorySpace()));  // base memref (unranked)
+        resultTypes.push_back(indexType); // offset
+        for (unsigned i = 0; i < rank; ++i)
+          resultTypes.push_back(indexType); // strides
+        for (unsigned i = 0; i < rank; ++i)
+          resultTypes.push_back(indexType); // sizes
+
+        auto meta = memref::ExtractStridedMetadataOp::create(
+            builder, loc, resultTypes, input);
+
+        auto addr = memref::ExtractAlignedPointerAsIndexOp::create(
+            builder, loc, meta.getBaseBuffer());
+        auto offset = meta.getOffset();
+
+        auto addr_casted =
+            arith::IndexCastUIOp::create(builder, loc, type, addr);
+        auto offset_casted =
+            arith::IndexCastUIOp::create(builder, loc, type, offset);
+
+        // Compute the final address: base address + byte offset
+        auto byte_size = arith::ConstantOp::create(
+            builder, loc, type,
+            builder.getIntegerAttr(type,
+                                   memrefTy.getElementTypeBitWidth() / 8));
+        auto byte_offset =
+            arith::MulIOp::create(builder, loc, offset_casted, byte_size);
+        auto addr_with_offset =
+            arith::AddIOp::create(builder, loc, addr_casted, byte_offset);
+
+        return addr_with_offset.getResult();
       }
       return {};
     };
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index ac95a1a5707ea..aba73b80f0439 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -33,22 +33,29 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     
     %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
 
+    //CHECK: %[[base_buffer:.*]], %[[offset:.*]], %[[sizes:.*]]:2, %[[strides:.*]]:2 = memref.extract_strided_metadata %{{.*}} : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> memref<f32, 3>, index, index, index, index, index
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer]] : memref<f32, 3> -> index
+    //CHECK: %[[ptr_i32:.*]] = arith.index_castui %[[intptr]] : index to i32
+    //CHECK: %[[offset_i32:.*]] = arith.index_castui %[[offset]] : index to i32
+    //CHECK: %[[c4_i32:.*]] = arith.constant 4 : i32
+    //CHECK: %[[mul:.*]] = arith.muli %[[offset_i32]], %[[c4_i32]] : i32
+    //CHECK: %[[add:.*]] = arith.addi %[[ptr_i32]], %[[mul]] : i32
+
     %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
 
     //CHECK: %[[TID:.*]] = gpu.thread_id x
     //CHECK: %[[C1:.*]] = arith.constant 1 : index
     //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
-    //CHECK: %[[C4:.*]] = arith.constant 4 : i32
-    //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
+     //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, {{.*}} : i32
     //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
 
     %tid_x = gpu.thread_id x
- 
+  
     %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
 
     //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
 
-     xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+    xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
 
     gpu.return %1: f32
   }
@@ -99,8 +106,6 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   //CHECK-LABEL: load_store_matrix_blocked_nostride
   gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
 
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
-    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
     
     //CHECK: %[[tid_x:.*]] = gpu.thread_id x
@@ -178,7 +183,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
   gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
 
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
+    //CHECK: %[[base_buffer:.*]], %[[offset:.*]], %[[sizes:.*]], %[[strides:.*]] = memref.extract_strided_metadata %arg0 : memref<4096xi8, 3> -> memref<i8, 3>, index, index, index
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer]] : memref<i8, 3> -> index
     //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
      %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
 
@@ -206,7 +212,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
     //CHECK: %[[c2:.*]] = arith.constant 2 : i32
     //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
-    //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32]], %[[byteOffset]] : i32
+    //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32:.*]], %[[byteOffset]] : i32
     //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
     //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
     //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>

>From 2664dc1c1bc88e531db8c5aa40fe1bfb3004b9d0 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 3 Dec 2025 19:21:54 +0000
Subject: [PATCH 2/2] add static offset support

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |   3 +
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    |  52 +++--
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  12 ++
 .../XeGPUToXeVM/create_nd_tdesc.mlir          |   8 +-
 .../XeGPUToXeVM/loadstore_matrix.mlir         | 188 +++++++++---------
 .../Conversion/XeGPUToXeVM/loadstore_nd.mlir  |   4 +-
 6 files changed, 146 insertions(+), 121 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 58092c3bb9ed2..b5978dc8d7b74 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -175,6 +175,9 @@ template <typename T>
 int getLargestDivisor(T dim, ArrayRef<T> candidates,
                       ArrayRef<T> candidateMultiples = {});
 
+/// Checks if the given MemRefType refers to shared memory.
+bool isSharedMemRef(const MemRefType &memrefTy);
+
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index bafd1dc348e5b..a1c2745864dc6 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -991,9 +991,8 @@ struct ConvertXeGPUToXeVMPass
     });
 
     typeConverter.addConversion([&](MemRefType type) -> Type {
-      if (type.getMemorySpaceAsInt() == 3)
-        return IntegerType::get(&getContext(), 32);
-      return IntegerType::get(&getContext(), 64);
+      return IntegerType::get(&getContext(),
+                              (xegpu::isSharedMemRef(type) ? 32 : 64));
     });
 
     // LLVM type converter puts unrealized casts for the following cases:
@@ -1010,24 +1009,35 @@ struct ConvertXeGPUToXeVMPass
         unsigned rank = memrefTy.getRank();
         Type indexType = builder.getIndexType();
 
-        SmallVector<Type> resultTypes;
-        // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
-        // size0, size1, ..., sizeN-1]
-        resultTypes.push_back(MemRefType::get(
-            {}, memrefTy.getElementType(), MemRefLayoutAttrInterface(),
-            memrefTy.getMemorySpace()));  // base memref (unranked)
-        resultTypes.push_back(indexType); // offset
-        for (unsigned i = 0; i < rank; ++i)
-          resultTypes.push_back(indexType); // strides
-        for (unsigned i = 0; i < rank; ++i)
-          resultTypes.push_back(indexType); // sizes
-
-        auto meta = memref::ExtractStridedMetadataOp::create(
-            builder, loc, resultTypes, input);
-
-        auto addr = memref::ExtractAlignedPointerAsIndexOp::create(
-            builder, loc, meta.getBaseBuffer());
-        auto offset = meta.getOffset();
+        int64_t intOffsets;
+        SmallVector<int64_t> intStrides;
+        Value addr;
+        Value offset;
+        if (failed(memrefTy.getStridesAndOffset(intStrides, intOffsets))) {
+
+          // Result types: [base_memref, offset, size0, size1, ..., sizeN-1,
+          // stride0, stride1, ..., strideN-1]
+          SmallVector<Type> resultTypes{
+              MemRefType::get({}, memrefTy.getElementType(),
+                              MemRefLayoutAttrInterface(),
+                              memrefTy.getMemorySpace()),
+              indexType};
+          // sizes + strides
+          resultTypes.append(2 * rank, indexType);
+
+          auto meta = memref::ExtractStridedMetadataOp::create(
+              builder, loc, resultTypes, input);
+
+          addr = memref::ExtractAlignedPointerAsIndexOp::create(
+              builder, loc, meta.getBaseBuffer());
+          offset = meta.getOffset();
+
+        } else {
+          addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
+                                                                input);
+          offset = arith::ConstantOp::create(builder, loc,
+                                             builder.getIndexAttr(intOffsets));
+        }
 
         auto addr_casted =
             arith::IndexCastUIOp::create(builder, loc, type, addr);
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 91432b1c11304..eecbb7b907e9f 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -580,3 +580,15 @@ template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
 template int
 xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
                                    ArrayRef<unsigned> candidateMultiples);
+
+/// Checks if the given MemRefType refers to shared memory.
+bool xegpu::isSharedMemRef(const MemRefType &memrefTy) {
+  Attribute attr = memrefTy.getMemorySpace();
+  if (!attr)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+    return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
+  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 8b87b791c9fd3..242101955b900 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -8,7 +8,8 @@ gpu.module @create_nd_tdesc {
   gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
   %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
         // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
-        // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+        // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64    
+        
         // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
         // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
         // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
@@ -39,7 +40,7 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
         // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
         // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2_OFFSET:.*]], %[[VAR14]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
@@ -53,13 +54,14 @@ gpu.module @create_nd_tdesc {
         %size_x = arith.constant 64 : index
         // CHECK: %[[C16:.*]] = arith.constant 16 : index
         %BLOCK_DMODEL = arith.constant 16 : index
+   
         // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
         // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
         // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
         // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
         // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
         // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VAR25:.*]] = vector.insert %[[BASE_ADDR3_OFFSET:.*]], %[[VAR24]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index aba73b80f0439..179fd397d7074 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -33,21 +33,20 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     
     %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
 
-    //CHECK: %[[base_buffer:.*]], %[[offset:.*]], %[[sizes:.*]]:2, %[[strides:.*]]:2 = memref.extract_strided_metadata %{{.*}} : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> memref<f32, 3>, index, index, index, index, index
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer]] : memref<f32, 3> -> index
-    //CHECK: %[[ptr_i32:.*]] = arith.index_castui %[[intptr]] : index to i32
-    //CHECK: %[[offset_i32:.*]] = arith.index_castui %[[offset]] : index to i32
-    //CHECK: %[[c4_i32:.*]] = arith.constant 4 : i32
-    //CHECK: %[[mul:.*]] = arith.muli %[[offset_i32]], %[[c4_i32]] : i32
-    //CHECK: %[[add:.*]] = arith.addi %[[ptr_i32]], %[[mul]] : i32
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer:.*]] : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> index
+    //CHECK-DAG: %[[ptr_i32:.*]] = arith.index_castui %[[intptr]] : index to i32
+    //CHECK-DAG: %[[offset_i32:.*]] = arith.index_castui %[[offset:.*]] : index to i32
+    //CHECK-DAG: %[[c4_i32:.*]] = arith.constant 4 : i32
+    //CHECK-DAG: %[[mul:.*]] = arith.muli %[[offset_i32]], %[[c4_i32]] : i32
+    //CHECK-DAG: %[[add:.*]] = arith.addi %[[ptr_i32]], %[[mul]] : i32
 
     %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
 
-    //CHECK: %[[TID:.*]] = gpu.thread_id x
-    //CHECK: %[[C1:.*]] = arith.constant 1 : index
-    //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
-     //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, {{.*}} : i32
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+    //CHECK-DAG: %[[TID:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+     //CHECK-DAG: %[[MUL2:.*]] = arith.muli {{.*}}, {{.*}} : i32
+    //CHECK-DAG: llvm.load {{.*}} : !llvm.ptr<3> -> f32
 
     %tid_x = gpu.thread_id x
   
@@ -67,25 +66,25 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
 
-    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
-    //CHECK: %[[c13:.*]] = arith.constant 13 : index
-    //CHECK: %[[c16:.*]] = arith.constant 16 : index
-    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
-    //CHECK: %[[c512:.*]] = arith.constant 512 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
-    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
-    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
-    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[c13:.*]] = arith.constant 13 : index
+    //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
+    //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+    //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK-DAG: %[[c512:.*]] = arith.constant 512 : index
+    //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+    //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+    //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+    //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
 
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
  
@@ -108,29 +107,29 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
 
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
     
-    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
-    //CHECK: %[[c19:.*]] = arith.constant 19 : index
+    //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[c19:.*]] = arith.constant 19 : index
     %tid_x = gpu.thread_id x
     %c19 = arith.constant 19: index
     
-    //CHECK: %[[c16:.*]] = arith.constant 16 : index
-    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
-    //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
-    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
-    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
-    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-    //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
+    //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
+    //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[c1024:.*]] = arith.constant 1024 : index
+    //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
+    //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
+    //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
+    //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
+    //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK-DAG: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
     %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
     
     //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
@@ -146,24 +145,24 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
 
-    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
-    //CHECK: %[[c16:.*]] = arith.constant 16 : index
-    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
-    //CHECK: %[[c512:.*]] = arith.constant 512 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
-    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
-    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
-    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+    //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+    //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK-DAG: %[[c512:.*]] = arith.constant 512 : index
+    //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+    //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+    //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+    //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
 
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
      
@@ -183,9 +182,9 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
   gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
 
-    //CHECK: %[[base_buffer:.*]], %[[offset:.*]], %[[sizes:.*]], %[[strides:.*]] = memref.extract_strided_metadata %arg0 : memref<4096xi8, 3> -> memref<i8, 3>, index, index, index
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer]] : memref<i8, 3> -> index
-    //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer:.*]] : memref<4096xi8, 3> -> index
+    //CHECK-DAG: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
      %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
 
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
@@ -193,29 +192,28 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     %c16 = arith.constant 16 : index
     %c48 = arith.constant 48 : index
 
-    //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
-    //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
-    //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index
-    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index
-    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
-    //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-    //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
-    //CHECK: %[[c2:.*]] = arith.constant 2 : i32
-    //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
-    //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32:.*]], %[[byteOffset]] : i32
-    //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
-    //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
-    //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
+    //CHECK-DAG: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+    //CHECK-DAG: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+    //CHECK-DAG: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
+    //CHECK-DAG: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
+    //CHECK-DAG: %[[c1024:.*]] = arith.constant 1024 : index
+    //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
+    //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index
+    //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index
+    //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
+    //CHECK-DAG: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK-DAG: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
+    //CHECK-DAG: %[[c2:.*]] = arith.constant 2 : i32
+    //CHECK-DAG: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
+    //CHECK-DAG: %[[finalPtr:.*]] = arith.addi %[[basePtrI32:.*]], %[[byteOffset]] : i32
+    //CHECK-DAG: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
+    //CHECK-DAG: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
+    //CHECK-DAG: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
 
     %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
 
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index afeae8be24b72..30fbb66ec9e58 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -14,7 +14,7 @@ gpu.module @load_store_check {
         // CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<8x16xf32> -> index
         // CHECK: %[[ST_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR_1]] : index to i64
         // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64_OFFSET:.*]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
         // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
         // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
@@ -50,7 +50,7 @@ gpu.module @load_store_check {
         %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
 
         // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[DESC_0:.*]] = vector.insert %[[ST_PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[DESC_0:.*]] = vector.insert %[[ST_PTR_AS_I64_OFFSET:.*]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
         // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
         // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>



More information about the Mlir-commits mailing list