[Mlir-commits] [mlir] [MLIR][XeGPU] Support subview memref: handling the base address during xegpu to xevm type conversion (PR #170541)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 3 11:33:35 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jianhui Li (Jianhui-Li)

<details>
<summary>Changes</summary>

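During the XeGPU-to-XeVM type conversion, a memref is lowered to its base address. This PR extends the conversion to correctly handle memrefs that include an offset, such as those generated by memref.subview. The lowered address is now computed as the aligned base pointer plus the memref's offset multiplied by the element size in bytes.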


---

Patch is 25.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170541.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h (+3) 
- (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+51-8) 
- (modified) mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp (+12) 
- (modified) mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir (+5-3) 
- (modified) mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir (+96-92) 
- (modified) mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir (+2-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 58092c3bb9ed2..b5978dc8d7b74 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -175,6 +175,9 @@ template <typename T>
 int getLargestDivisor(T dim, ArrayRef<T> candidates,
                       ArrayRef<T> candidateMultiples = {});
 
+/// Checks if the given MemRefType refers to shared memory.
+bool isSharedMemRef(const MemRefType &memrefTy);
+
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 7f1ec17ce0ae8..a1c2745864dc6 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -991,15 +991,14 @@ struct ConvertXeGPUToXeVMPass
     });
 
     typeConverter.addConversion([&](MemRefType type) -> Type {
-      if (type.getMemorySpaceAsInt() == 3)
-        return IntegerType::get(&getContext(), 32);
-      return IntegerType::get(&getContext(), 64);
+      return IntegerType::get(&getContext(),
+                              (xegpu::isSharedMemRef(type) ? 32 : 64));
     });
 
     // LLVM type converter puts unrealized casts for the following cases:
     // add materialization casts to handle them.
 
-    // Materialization to convert memref to i64
+    // Materialization to convert memref to i64 or i32 depending on global/SLM
     auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                         ValueRange inputs,
                                         Location loc) -> Value {
@@ -1007,11 +1006,55 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
+        unsigned rank = memrefTy.getRank();
+        Type indexType = builder.getIndexType();
 
-        Value addr =
-            memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
-        return arith::IndexCastUIOp::create(builder, loc, type, addr)
-            .getResult();
+        int64_t intOffsets;
+        SmallVector<int64_t> intStrides;
+        Value addr;
+        Value offset;
+        if (failed(memrefTy.getStridesAndOffset(intStrides, intOffsets))) {
+
+          // Result types: [base_memref, offset, stride0, stride1, ...,
+          // strideN-1, size0, size1, ..., sizeN-1]
+          SmallVector<Type> resultTypes{
+              MemRefType::get({}, memrefTy.getElementType(),
+                              MemRefLayoutAttrInterface(),
+                              memrefTy.getMemorySpace()),
+              indexType};
+          // strides + sizes
+          resultTypes.append(2 * rank, indexType);
+
+          auto meta = memref::ExtractStridedMetadataOp::create(
+              builder, loc, resultTypes, input);
+
+          addr = memref::ExtractAlignedPointerAsIndexOp::create(
+              builder, loc, meta.getBaseBuffer());
+          offset = meta.getOffset();
+
+        } else {
+          addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
+                                                                input);
+          offset = arith::ConstantOp::create(builder, loc,
+                                             builder.getIndexAttr(intOffsets));
+        }
+
+        auto addr_casted =
+            arith::IndexCastUIOp::create(builder, loc, type, addr);
+        auto offset_casted =
+            arith::IndexCastUIOp::create(builder, loc, type, offset);
+
+        // Compute the final address: base address + byte offset
+        auto byte_size = arith::ConstantOp::create(
+            builder, loc, type,
+            builder.getIntegerAttr(type,
+                                   memrefTy.getElementTypeBitWidth() / 8));
+        auto byte_offset =
+            arith::MulIOp::create(builder, loc, offset_casted, byte_size);
+        auto addr_with_offset =
+            arith::AddIOp::create(builder, loc, addr_casted, byte_offset);
+
+        return addr_with_offset.getResult();
       }
       return {};
     };
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 91432b1c11304..eecbb7b907e9f 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -580,3 +580,15 @@ template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
 template int
 xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
                                    ArrayRef<unsigned> candidateMultiples);
+
+/// Checks if the given MemRefType refers to shared memory.
+bool xegpu::isSharedMemRef(const MemRefType &memrefTy) {
+  Attribute attr = memrefTy.getMemorySpace();
+  if (!attr)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+    return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
+  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 8b87b791c9fd3..242101955b900 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -8,7 +8,8 @@ gpu.module @create_nd_tdesc {
   gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
   %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
         // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
-        // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+        // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64    
+        
         // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
         // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
         // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
@@ -39,7 +40,7 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
         // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
         // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2_OFFSET:.*]], %[[VAR14]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
@@ -53,13 +54,14 @@ gpu.module @create_nd_tdesc {
         %size_x = arith.constant 64 : index
         // CHECK: %[[C16:.*]] = arith.constant 16 : index
         %BLOCK_DMODEL = arith.constant 16 : index
+   
         // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
         // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
         // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
         // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
         // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
         // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VAR25:.*]] = vector.insert %[[BASE_ADDR3_OFFSET:.*]], %[[VAR24]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index ac95a1a5707ea..179fd397d7074 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -33,22 +33,28 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     
     %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
 
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer:.*]] : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> index
+    //CHECK-DAG: %[[ptr_i32:.*]] = arith.index_castui %[[intptr]] : index to i32
+    //CHECK-DAG: %[[offset_i32:.*]] = arith.index_castui %[[offset:.*]] : index to i32
+    //CHECK-DAG: %[[c4_i32:.*]] = arith.constant 4 : i32
+    //CHECK-DAG: %[[mul:.*]] = arith.muli %[[offset_i32]], %[[c4_i32]] : i32
+    //CHECK-DAG: %[[add:.*]] = arith.addi %[[ptr_i32]], %[[mul]] : i32
+
     %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
 
-    //CHECK: %[[TID:.*]] = gpu.thread_id x
-    //CHECK: %[[C1:.*]] = arith.constant 1 : index
-    //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
-    //CHECK: %[[C4:.*]] = arith.constant 4 : i32
-    //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+    //CHECK-DAG: %[[TID:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+     //CHECK-DAG: %[[MUL2:.*]] = arith.muli {{.*}}, {{.*}} : i32
+    //CHECK-DAG: llvm.load {{.*}} : !llvm.ptr<3> -> f32
 
     %tid_x = gpu.thread_id x
- 
+  
     %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
 
     //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
 
-     xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+    xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
 
     gpu.return %1: f32
   }
@@ -60,25 +66,25 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
 
-    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
-    //CHECK: %[[c13:.*]] = arith.constant 13 : index
-    //CHECK: %[[c16:.*]] = arith.constant 16 : index
-    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
-    //CHECK: %[[c512:.*]] = arith.constant 512 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
-    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
-    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
-    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[c13:.*]] = arith.constant 13 : index
+    //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
+    //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+    //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK-DAG: %[[c512:.*]] = arith.constant 512 : index
+    //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+    //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+    //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+    //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
 
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
  
@@ -99,33 +105,31 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   //CHECK-LABEL: load_store_matrix_blocked_nostride
   gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
 
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
-    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
     
-    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
-    //CHECK: %[[c19:.*]] = arith.constant 19 : index
+    //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[c19:.*]] = arith.constant 19 : index
     %tid_x = gpu.thread_id x
     %c19 = arith.constant 19: index
     
-    //CHECK: %[[c16:.*]] = arith.constant 16 : index
-    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
-    //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
-    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
-    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
-    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-    //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
+    //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
+    //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[c1024:.*]] = arith.constant 1024 : index
+    //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
+    //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
+    //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
+    //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
+    //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK-DAG: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
     %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
     
     //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
@@ -141,24 +145,24 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
 
-    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
-    //CHECK: %[[c16:.*]] = arith.constant 16 : index
-    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
-    //CHECK: %[[c512:.*]] = arith.constant 512 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
-    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
-    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
-    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK-DAG: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK-DAG: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+    //CHECK-DAG: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK-DAG: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+    //CHECK-DAG: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK-DAG: %[[c512:.*]] = arith.constant 512 : index
+    //CHECK-DAG: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+    //CHECK-DAG: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK-DAG: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+    //CHECK-DAG: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK-DAG: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+    //CHECK-DAG: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
 
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
      
@@ -178,8 +182,9 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
   gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
 
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
-    //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
+    //CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK-DAG: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer:.*]] : memref<4096xi8, 3> -> index
+    //CHECK-DAG: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
      %0 = xegpu.create_mem_...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/170541


More information about the Mlir-commits mailing list