[Mlir-commits] [mlir] 8adcf0a - [MLIR][XeGPU] Support subview memref: handling the base address during xegpu to xevm type conversion (#170541)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 10 08:53:22 PST 2025


Author: Jianhui Li
Date: 2025-12-10T08:53:18-08:00
New Revision: 8adcf0ad5a9ff8649e98120e9101cf48c1585ead

URL: https://github.com/llvm/llvm-project/commit/8adcf0ad5a9ff8649e98120e9101cf48c1585ead
DIFF: https://github.com/llvm/llvm-project/commit/8adcf0ad5a9ff8649e98120e9101cf48c1585ead.diff

LOG: [MLIR][XeGPU] Support subview memref: handling the base address during xegpu to xevm type conversion (#170541)

During the XeGPU-to-XeVM type conversion, a memref is lowered to its
base address. This PR extends the conversion to correctly handle memrefs
that include an offset, such as those generated by memref.subview.

Added: 
    

Modified: 
    mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
    mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
    mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
    mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 0ecb50e719d58..8a06271eadd84 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -66,6 +66,18 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
   llvm_unreachable("Unknown XeGPU memory space");
 }
 
+/// Returns true if memrefTy's memory space is xevm SHARED or gpu workgroup.
+static bool isSharedMemRef(const MemRefType &memrefTy) {
+  Attribute attr = memrefTy.getMemorySpace();
+  if (!attr)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+    return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
+  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
+
 // Get same bitwidth flat vector type of new element type.
 static VectorType encodeVectorTypeTo(VectorType currentVecType,
                                      Type toElemType) {
@@ -1066,15 +1078,13 @@ struct ConvertXeGPUToXeVMPass
     });
 
     typeConverter.addConversion([&](MemRefType type) -> Type {
-      if (type.getMemorySpaceAsInt() == 3)
-        return IntegerType::get(&getContext(), 32);
-      return IntegerType::get(&getContext(), 64);
+      return IntegerType::get(&getContext(), (isSharedMemRef(type) ? 32 : 64));
     });
 
     // LLVM type converter puts unrealized casts for the following cases:
     // add materialization casts to handle them.
 
-    // Materialization to convert memref to i64
+    // Materialization to convert memref to i64 or i32 depending on global/SLM
     auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                         ValueRange inputs,
                                         Location loc) -> Value {
@@ -1082,11 +1092,55 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
+        unsigned rank = memrefTy.getRank();
+        Type indexType = builder.getIndexType();
 
-        Value addr =
-            memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
-        return arith::IndexCastUIOp::create(builder, loc, type, addr)
-            .getResult();
+        int64_t intOffsets;
+        SmallVector<int64_t> intStrides;
+        Value addr;
+        Value offset;
+        if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
+            ShapedType::isStatic(intOffsets)) {
+          addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
+                                                                input);
+          offset = arith::ConstantOp::create(builder, loc,
+                                             builder.getIndexAttr(intOffsets));
+        } else {
+
+          // Result types: [base_memref, offset, size0, size1, ...,
+          // sizeN-1, stride0, stride1, ..., strideN-1]
+          SmallVector<Type> resultTypes{
+              MemRefType::get({}, memrefTy.getElementType(),
+                              MemRefLayoutAttrInterface(),
+                              memrefTy.getMemorySpace()),
+              indexType};
+          // sizes + strides
+          resultTypes.append(2 * rank, indexType);
+
+          auto meta = memref::ExtractStridedMetadataOp::create(
+              builder, loc, resultTypes, input);
+
+          addr = memref::ExtractAlignedPointerAsIndexOp::create(
+              builder, loc, meta.getBaseBuffer());
+          offset = meta.getOffset();
+        }
+
+        auto addrCasted =
+            arith::IndexCastUIOp::create(builder, loc, type, addr);
+        auto offsetCasted =
+            arith::IndexCastUIOp::create(builder, loc, type, offset);
+
+        // Compute the final address: base address + byte offset
+        auto byteSize = arith::ConstantOp::create(
+            builder, loc, type,
+            builder.getIntegerAttr(type,
+                                   memrefTy.getElementTypeBitWidth() / 8));
+        auto byteOffset =
+            arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
+        auto addrWithOffset =
+            arith::AddIOp::create(builder, loc, addrCasted, byteOffset);
+
+        return addrWithOffset.getResult();
       }
       return {};
     };

diff  --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 9a1e2cb3c7de0..34654126ce8d2 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -37,7 +37,7 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[C32_I64_2:.*]] = arith.constant 32 : i64
         // CHECK: %[[PITCH2:.*]] = arith.trunci %[[C32_I64_2]] : i64 to i32
         // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2_OFFSET:.*]], %[[VAR14]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
@@ -55,7 +55,7 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[SHAPE_H3:.*]] = arith.index_cast %[[C64]] : index to i32
         // CHECK: %[[PITCH3:.*]] = arith.index_cast %[[C16]] : index to i32
         // CHECK: %[[VAR25:.*]] = vector.bitcast %[[CST_3]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR]], %[[VAR25]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR_OFFSET:.*]], %[[VAR25]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR27:.*]] = vector.bitcast %[[VAR26]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[VAR28:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR27]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR28]] [3] : i32 into vector<8xi32>

diff  --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index 3a3769f3a4f70..fa683175693be 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -1,11 +1,19 @@
-// RUN: mlir-opt  -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s
+// RUN: mlir-opt  -split-input-file -convert-xegpu-to-xevm %s | FileCheck %s
 
 gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
 
- // e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
+  // e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
   // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
   //CHECK-LABEL: load_store_matrix_plain
   gpu.func @load_store_matrix_plain(%arg0: memref<4096xi8, 3>) -> f32 {
+
+    //CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
+    //CHECK: %[[C0:.*]] = arith.constant 0 : index
+    //CHECK: %[[CAST0:.*]] = arith.index_castui %[[INTPTR]] : index to i32
+    //CHECK: %[[CAST1:.*]] = arith.index_castui %[[C0]] : index to i32
+    //CHECK: %[[C1_I32:.*]] = arith.constant 1 : i32
+    //CHECK: %[[MUL:.*]] = arith.muli %[[CAST1]], %[[C1_I32]] : i32
+    //CHECK: %[[ADD:.*]] = arith.addi %[[CAST0]], %[[MUL]] : i32
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
 
     //CHECK: %[[TID:.*]] = gpu.thread_id x
@@ -27,28 +35,33 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   }
 
   //CHECK-LABEL: load_store_matrix_plain_2d_input
-  gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
+  gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<1024xi8, 3>) -> f32 {
     %c0 = arith.constant 0 : index
-    %view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
+    %view = memref.view %arg0[%c0][]: memref<1024xi8, 3> to memref<64x32xf32, 3>
 
     %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
 
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer:.*]] : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> index
+    //CHECK: %[[ptr_i32:.*]] = arith.index_castui %[[intptr]] : index to i32
+    //CHECK: %[[offset_i32:.*]] = arith.index_castui %[[offset:.*]] : index to i32
+    //CHECK: %[[c4_i32:.*]] = arith.constant 4 : i32
+    //CHECK: %[[mul:.*]] = arith.muli %[[offset_i32]], %[[c4_i32]] : i32
+    //CHECK: %[[add:.*]] = arith.addi %[[ptr_i32]], %[[mul]] : i32
+
     %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
 
     //CHECK: %[[TID:.*]] = gpu.thread_id x
     //CHECK: %[[C1:.*]] = arith.constant 1 : index
     //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
-    //CHECK: %[[C4:.*]] = arith.constant 4 : i32
-    //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
+    //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, {{.*}} : i32
     //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
 
     %tid_x = gpu.thread_id x
-
     %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
 
     //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
 
-     xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+    xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
 
     gpu.return %1: f32
   }
@@ -64,9 +77,12 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[c13:.*]] = arith.constant 13 : index
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
     //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK: %[[c16_0:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16_0]] : index
+    //CHECK: %[[c16_1:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16_1]] : index
+    //CHECK: %[[c16_2:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16_2]] : index
     //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[c256:.*]] = arith.constant 256 : index
     //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
@@ -77,11 +93,15 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[c1:.*]] = arith.constant 1 : index
     //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
     //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+    //CHECK: %[[c16_4:.*]] = arith.constant 16 : index
+    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16_4]] : index
     //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-
-    //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
-
+    //CHECK: %[[cast:.*]] = arith.index_castui %[[add3]] : index to i32
+    //CHECK: %[[c2_i32:.*]] = arith.constant 2 : i32
+    //CHECK: %[[byte_offset:.*]] = arith.muli %[[cast]], %[[c2_i32]] : i32
+    //CHECK: %[[final_ptr:.*]] = arith.addi {{.*}}, %[[byte_offset]] : i32
+    //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[final_ptr]] : i32 to !llvm.ptr<3>
+    //CHECK: %[[loaded:.*]] = llvm.load %[[ptr]] : !llvm.ptr<3> -> f16
 
     %tid_x = gpu.thread_id x
     %c13 = arith.constant 13 : index
@@ -93,14 +113,17 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     gpu.return %1: f16
   }
 
-
   // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
   // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
   //CHECK-LABEL: load_store_matrix_blocked_nostride
   gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
-
     //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
-    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK: %[[cast0:.*]] = arith.index_castui %[[intptr]] : index to i32
+    //CHECK: %[[cast1:.*]] = arith.index_castui %[[c0]] : index to i32
+    //CHECK: %[[c1_i32:.*]] = arith.constant 1 : i32
+    //CHECK: %[[mul:.*]] = arith.muli %[[cast1]], %[[c1_i32]] : i32
+    //CHECK: %[[add:.*]] = arith.addi %[[cast0]], %[[mul]] : i32
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
 
     //CHECK: %[[tid_x:.*]] = gpu.thread_id x
@@ -110,9 +133,12 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
 
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
     //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK: %[[c16_0:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16_0]] : index
+    //CHECK: %[[c16_1:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16_1]] : index
+    //CHECK: %[[c16_2:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16_2]] : index
     //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
     //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
@@ -120,12 +146,18 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[c256:.*]] = arith.constant 256 : index
     //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
     //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
+    //CHECK: %[[c16_4:.*]] = arith.constant 16 : index
+    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16_4]] : index
     //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
     //CHECK: %[[c1:.*]] = arith.constant 1 : index
     //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
     //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-    //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
+    //CHECK: %[[cast:.*]] = arith.index_castui %[[add3]] : index to i32
+    //CHECK: %[[c2_i32:.*]] = arith.constant 2 : i32
+    //CHECK: %[[byte_offset:.*]] = arith.muli %[[cast]], %[[c2_i32]] : i32
+    //CHECK: %[[final_ptr:.*]] = arith.addi {{.*}}, %[[byte_offset]] : i32
+    //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[final_ptr]] : i32 to !llvm.ptr<3>
+    //CHECK: %[[loaded:.*]] = llvm.load %[[ptr]] : !llvm.ptr<3> -> f16
     %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
 
     //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
@@ -143,10 +175,14 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
 
     //CHECK: %[[tid_x:.*]] = gpu.thread_id x
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
-    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
-    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK: %[[c16_0:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16_0]] : index
+    //CHECK: %[[c16_1:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16_1]] : index
+    //CHECK: %[[c16_2:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16_2]] : index
+    //CHECK: %[[c16_3:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16_3]] : index
     //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[c256:.*]] = arith.constant 256 : index
     //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
@@ -157,10 +193,15 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[c1:.*]] = arith.constant 1 : index
     //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
     //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+    //CHECK: %[[c16_5:.*]] = arith.constant 16 : index
+    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16_5]] : index
     //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-
-    //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
+    //CHECK: %[[cast:.*]] = arith.index_castui %[[add3]] : index to i32
+    //CHECK: %[[c2_i32:.*]] = arith.constant 2 : i32
+    //CHECK: %[[byte_offset:.*]] = arith.muli %[[cast]], %[[c2_i32]] : i32
+    //CHECK: %[[final_ptr:.*]] = arith.addi {{.*}}, %[[byte_offset]] : i32
+    //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[final_ptr]] : i32 to !llvm.ptr<3>
+    //CHECK: %[[loaded:.*]] = llvm.load %[[ptr]] : !llvm.ptr<3> -> vector<8xf16>
 
     %tid_x = gpu.thread_id x
     %c16 = arith.constant 16 : index
@@ -179,42 +220,52 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
 
     //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
-    //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
-     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK: %[[cast0:.*]] = arith.index_castui %[[intptr]] : index to i32
+    //CHECK: %[[cast1:.*]] = arith.index_castui %[[c0]] : index to i32
+    //CHECK: %[[c1_i32:.*]] = arith.constant 1 : i32
+    //CHECK: %[[mul:.*]] = arith.muli %[[cast1]], %[[c1_i32]] : i32
+    //CHECK: %[[add:.*]] = arith.addi %[[cast0]], %[[mul]] : i32
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
     //CHECK: %[[c48:.*]] = arith.constant 48 : index
-    %c16 = arith.constant 16 : index
-    %c48 = arith.constant 48 : index
-
-    //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
-    //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
-    //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK: %[[c16_0:.*]] = arith.constant 16 : index
+    //CHECK: %[[divsi0:.*]] = arith.divsi %[[c16]], %[[c16_0]] : index
+    //CHECK: %[[c16_1:.*]] = arith.constant 16 : index
+    //CHECK: %[[remsi0:.*]] = arith.remsi %[[c16]], %[[c16_1]] : index
+    //CHECK: %[[c16_2:.*]] = arith.constant 16 : index
+    //CHECK: %[[divsi1:.*]] = arith.divsi %[[c48]], %[[c16_2]] : index
+    //CHECK: %[[c16_3:.*]] = arith.constant 16 : index
+    //CHECK: %[[remsi1:.*]] = arith.remsi %[[c48]], %[[c16_3]] : index
+    //CHECK: %[[c0_4:.*]] = arith.constant 0 : index
     //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
-    //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
-    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK: %[[mul0:.*]] = arith.muli %[[divsi0]], %[[c1024]] : index
+    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0_4]] : index
     //CHECK: %[[c256:.*]] = arith.constant 256 : index
-    //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index
+    //CHECK: %[[mul1:.*]] = arith.muli %[[divsi1]], %[[c256]] : index
     //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
-    //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index
+    //CHECK: %[[c16_5:.*]] = arith.constant 16 : index
+    //CHECK: %[[mul2:.*]] = arith.muli %[[remsi0]], %[[c16_5]] : index
     //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
     //CHECK: %[[c1:.*]] = arith.constant 1 : index
-    //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
-    //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-    //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
-    //CHECK: %[[c2:.*]] = arith.constant 2 : i32
-    //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
-    //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32]], %[[byteOffset]] : i32
-    //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
-    //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
-    //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
+    //CHECK: %[[mul3:.*]] = arith.muli %[[remsi1]], %[[c1]] : index
+    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK: %[[cast:.*]] = arith.index_castui %[[add3]] : index to i32
+    //CHECK: %[[c2_i32:.*]] = arith.constant 2 : i32
+    //CHECK: %[[byte_offset:.*]] = arith.muli %[[cast]], %[[c2_i32]] : i32
+    //CHECK: %[[final_ptr:.*]] = arith.addi %[[add]], %[[byte_offset]] : i32
+    //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[final_ptr]] : i32 to !llvm.ptr<3>
+    //CHECK: %[[blockload:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
+    //CHECK: %[[loaded:.*]] = vector.bitcast %[[blockload]] : vector<8xi16> to vector<8xf16>
+    
+    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+
+    %c16 = arith.constant 16 : index
+    %c48 = arith.constant 48 : index
 
     %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
 
     //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16>
-    //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
+    //CHECK: xevm.blockstore %[[ptr2:.*]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
 
     xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
 
@@ -232,4 +283,29 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     xegpu.store_matrix %loaded_2, %matrixdesc[16,0] : vector<1x16xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
     gpu.return
   }
+
+  //CHECK-LABEL: create_memdesc_from_subview
+  gpu.func @create_memdesc_from_subview(%arg0: memref<256x16xbf16, 3>) -> vector<1x16xbf16> {
+
+    %c0 = arith.constant 0 : index
+
+  %smem_coop_a = memref.subview %arg0[64, 0][1, 16][1, 1] : memref<256x16xbf16, 3> to memref<1x16xbf16, strided<[16, 1], offset: 1024>, 3>
+
+  //CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %{{.*}} : memref<1x16xbf16, strided<[16, 1], offset: 1024>, 3> -> index
+  //CHECK: %[[C1024:.*]] = arith.constant 1024 : index
+  //CHECK: %[[CAST0:.*]] = arith.index_castui %[[INTPTR]] : index to i32
+  //CHECK: %[[CAST1:.*]] = arith.index_castui %[[C1024]] : index to i32
+  //CHECK: %[[C2:.*]] = arith.constant 2 : i32
+  //CHECK: %[[MUL:.*]] = arith.muli %[[CAST1]], %[[C2]] : i32
+  //CHECK: %{{.*}} = arith.addi %[[CAST0]], %[[MUL]] : i32
+
+  %mdesc_coop_a = xegpu.create_mem_desc %smem_coop_a : memref<1x16xbf16, strided<[16, 1], offset: 1024>, 3> -> !xegpu.mem_desc<1x16xbf16>
+
+  %ret = xegpu.load_matrix%mdesc_coop_a[%c0, %c0]: !xegpu.mem_desc<1x16xbf16>, index, index -> vector<1x16xbf16>
+
+  gpu.return  %ret : vector<1x16xbf16>
+
+  }
+
+
 }

diff  --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
index d606cf51435dc..4c84699a069f0 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -107,3 +107,28 @@ gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vecto
   gpu.return
 }
 }
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @load_gather_from_dyn_memref_subview
+gpu.func @load_gather_from_dyn_memref_subview(%dyn: memref<?xf16>, %offset: vector<1xindex>, %mask: vector<1xi1>, %dst: memref<1xf16>) {
+  %c0 = arith.constant 0 : index
+  %id = gpu.subgroup_id : index
+  %src = memref.subview %dyn[%id][16][1] : memref<?xf16> to memref<16xf16, strided<[1], offset: ?>>
+
+  // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %{{.*}} : memref<16xf16, strided<[1], offset: ?>> -> memref<f16>, index, index, index
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE]] : memref<f16> -> index
+  // CHECK: %[[CAST1:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+  // CHECK: %[[CAST2:.*]] = arith.index_castui %[[OFFSET]] : index to i64
+  // CHECK: %[[MUL1:.*]] = arith.muli %[[CAST2]], %{{.*}} : i64
+  // CHECK: %[[ADD1:.*]] = arith.addi %[[CAST1]], %[[MUL1]] : i64
+  // CHECK: %[[MUL2:.*]] = arith.muli %{{.*}}, %{{.*}} : i64
+  // CHECK: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL2]] : i64
+  // CHECK: %{{.*}} = llvm.inttoptr %[[ADD2]] : i64 to !llvm.ptr<1>
+
+  %0 = xegpu.load %src[%offset], %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : memref<16xf16, strided<[1], offset: ?>>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+  vector.store %0, %dst[%c0] : memref<1xf16>, vector<1xf16>
+  gpu.return
+}
+}


        


More information about the Mlir-commits mailing list