[Mlir-commits] [mlir] 8adcf0a - [MLIR][XeGPU] Support subview memref: handling the base address during xegpu to xevm type conversion (#170541)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 10 08:53:22 PST 2025
Author: Jianhui Li
Date: 2025-12-10T08:53:18-08:00
New Revision: 8adcf0ad5a9ff8649e98120e9101cf48c1585ead
URL: https://github.com/llvm/llvm-project/commit/8adcf0ad5a9ff8649e98120e9101cf48c1585ead
DIFF: https://github.com/llvm/llvm-project/commit/8adcf0ad5a9ff8649e98120e9101cf48c1585ead.diff
LOG: [MLIR][XeGPU] Support subview memref: handling the base address during xegpu to xevm type conversion (#170541)
During the XeGPU-to-XeVM type conversion, a memref is lowered to its
base address. This PR extends the conversion to correctly handle memrefs
that include an offset, such as those generated by memref.subview.
Added:
Modified:
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 0ecb50e719d58..8a06271eadd84 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -66,6 +66,18 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
llvm_unreachable("Unknown XeGPU memory space");
}
+/// Checks if the given MemRefType refers to shared memory.
+static bool isSharedMemRef(const MemRefType &memrefTy) {
+ Attribute attr = memrefTy.getMemorySpace();
+ if (!attr)
+ return false;
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+ return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
+ if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+ return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+ return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
+
// Get same bitwidth flat vector type of new element type.
static VectorType encodeVectorTypeTo(VectorType currentVecType,
Type toElemType) {
@@ -1066,15 +1078,13 @@ struct ConvertXeGPUToXeVMPass
});
typeConverter.addConversion([&](MemRefType type) -> Type {
- if (type.getMemorySpaceAsInt() == 3)
- return IntegerType::get(&getContext(), 32);
- return IntegerType::get(&getContext(), 64);
+ return IntegerType::get(&getContext(), (isSharedMemRef(type) ? 32 : 64));
});
// LLVM type converter puts unrealized casts for the following cases:
// add materialization casts to handle them.
- // Materialization to convert memref to i64
+ // Materialization to convert memref to i64 or i32 depending on global/SLM
auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
ValueRange inputs,
Location loc) -> Value {
@@ -1082,11 +1092,55 @@ struct ConvertXeGPUToXeVMPass
return {};
auto input = inputs.front();
if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
+ unsigned rank = memrefTy.getRank();
+ Type indexType = builder.getIndexType();
- Value addr =
- memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
- return arith::IndexCastUIOp::create(builder, loc, type, addr)
- .getResult();
+ int64_t intOffsets;
+ SmallVector<int64_t> intStrides;
+ Value addr;
+ Value offset;
+ if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
+ ShapedType::isStatic(intOffsets)) {
+ addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
+ input);
+ offset = arith::ConstantOp::create(builder, loc,
+ builder.getIndexAttr(intOffsets));
+ } else {
+
+ // Result types: [base_memref, offset, size0, size1, ...,
+ // sizeN-1, stride0, stride1, ..., strideN-1]
+ SmallVector<Type> resultTypes{
+ MemRefType::get({}, memrefTy.getElementType(),
+ MemRefLayoutAttrInterface(),
+ memrefTy.getMemorySpace()),
+ indexType};
+ // strides + sizes
+ resultTypes.append(2 * rank, indexType);
+
+ auto meta = memref::ExtractStridedMetadataOp::create(
+ builder, loc, resultTypes, input);
+
+ addr = memref::ExtractAlignedPointerAsIndexOp::create(
+ builder, loc, meta.getBaseBuffer());
+ offset = meta.getOffset();
+ }
+
+ auto addrCasted =
+ arith::IndexCastUIOp::create(builder, loc, type, addr);
+ auto offsetCasted =
+ arith::IndexCastUIOp::create(builder, loc, type, offset);
+
+ // Compute the final address: base address + byte offset
+ auto byteSize = arith::ConstantOp::create(
+ builder, loc, type,
+ builder.getIntegerAttr(type,
+ memrefTy.getElementTypeBitWidth() / 8));
+ auto byteOffset =
+ arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
+ auto addrWithOffset =
+ arith::AddIOp::create(builder, loc, addrCasted, byteOffset);
+
+ return addrWithOffset.getResult();
}
return {};
};
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 9a1e2cb3c7de0..34654126ce8d2 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -37,7 +37,7 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[C32_I64_2:.*]] = arith.constant 32 : i64
// CHECK: %[[PITCH2:.*]] = arith.trunci %[[C32_I64_2]] : i64 to i32
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
- // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2_OFFSET:.*]], %[[VAR14]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
@@ -55,7 +55,7 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[SHAPE_H3:.*]] = arith.index_cast %[[C64]] : index to i32
// CHECK: %[[PITCH3:.*]] = arith.index_cast %[[C16]] : index to i32
// CHECK: %[[VAR25:.*]] = vector.bitcast %[[CST_3]] : vector<8xi32> to vector<4xi64>
- // CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR]], %[[VAR25]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR_OFFSET:.*]], %[[VAR25]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR27:.*]] = vector.bitcast %[[VAR26]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR28:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR27]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR28]] [3] : i32 into vector<8xi32>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index 3a3769f3a4f70..fa683175693be 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -1,11 +1,19 @@
-// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm %s | FileCheck %s
gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
- // e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
+ // e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
// its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
//CHECK-LABEL: load_store_matrix_plain
gpu.func @load_store_matrix_plain(%arg0: memref<4096xi8, 3>) -> f32 {
+
+ //CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
+ //CHECK: %[[C0:.*]] = arith.constant 0 : index
+ //CHECK: %[[CAST0:.*]] = arith.index_castui %[[INTPTR]] : index to i32
+ //CHECK: %[[CAST1:.*]] = arith.index_castui %[[C0]] : index to i32
+ //CHECK: %[[C1_I32:.*]] = arith.constant 1 : i32
+ //CHECK: %[[MUL:.*]] = arith.muli %[[CAST1]], %[[C1_I32]] : i32
+ //CHECK: %[[ADD:.*]] = arith.addi %[[CAST0]], %[[MUL]] : i32
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
//CHECK: %[[TID:.*]] = gpu.thread_id x
@@ -27,28 +35,33 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
}
//CHECK-LABEL: load_store_matrix_plain_2d_input
- gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
+ gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<1024xi8, 3>) -> f32 {
%c0 = arith.constant 0 : index
- %view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
+ %view = memref.view %arg0[%c0][]: memref<1024xi8, 3> to memref<64x32xf32, 3>
%subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer:.*]] : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> index
+ //CHECK: %[[ptr_i32:.*]] = arith.index_castui %[[intptr]] : index to i32
+ //CHECK: %[[offset_i32:.*]] = arith.index_castui %[[offset:.*]] : index to i32
+ //CHECK: %[[c4_i32:.*]] = arith.constant 4 : i32
+ //CHECK: %[[mul:.*]] = arith.muli %[[offset_i32]], %[[c4_i32]] : i32
+ //CHECK: %[[add:.*]] = arith.addi %[[ptr_i32]], %[[mul]] : i32
+
%0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
//CHECK: %[[TID:.*]] = gpu.thread_id x
//CHECK: %[[C1:.*]] = arith.constant 1 : index
//CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
- //CHECK: %[[C4:.*]] = arith.constant 4 : i32
- //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
+ //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, {{.*}} : i32
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
%tid_x = gpu.thread_id x
-
%1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
//CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
- xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+ xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
gpu.return %1: f32
}
@@ -64,9 +77,12 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[c13:.*]] = arith.constant 13 : index
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
- //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
- //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
- //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[c16_0:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16_0]] : index
+ //CHECK: %[[c16_1:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16_1]] : index
+ //CHECK: %[[c16_2:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16_2]] : index
//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
@@ -77,11 +93,15 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[c1:.*]] = arith.constant 1 : index
//CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
//CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
- //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+ //CHECK: %[[c16_4:.*]] = arith.constant 16 : index
+ //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16_4]] : index
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-
- //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
-
+ //CHECK: %[[cast:.*]] = arith.index_castui %[[add3]] : index to i32
+ //CHECK: %[[c2_i32:.*]] = arith.constant 2 : i32
+ //CHECK: %[[byte_offset:.*]] = arith.muli %[[cast]], %[[c2_i32]] : i32
+ //CHECK: %[[final_ptr:.*]] = arith.addi {{.*}}, %[[byte_offset]] : i32
+ //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[final_ptr]] : i32 to !llvm.ptr<3>
+ //CHECK: %[[loaded:.*]] = llvm.load %[[ptr]] : !llvm.ptr<3> -> f16
%tid_x = gpu.thread_id x
%c13 = arith.constant 13 : index
@@ -93,14 +113,17 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
gpu.return %1: f16
}
-
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
//CHECK-LABEL: load_store_matrix_blocked_nostride
gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
-
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
- //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[cast0:.*]] = arith.index_castui %[[intptr]] : index to i32
+ //CHECK: %[[cast1:.*]] = arith.index_castui %[[c0]] : index to i32
+ //CHECK: %[[c1_i32:.*]] = arith.constant 1 : i32
+ //CHECK: %[[mul:.*]] = arith.muli %[[cast1]], %[[c1_i32]] : i32
+ //CHECK: %[[add:.*]] = arith.addi %[[cast0]], %[[mul]] : i32
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
@@ -110,9 +133,12 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
- //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
- //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
- //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[c16_0:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16_0]] : index
+ //CHECK: %[[c16_1:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16_1]] : index
+ //CHECK: %[[c16_2:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16_2]] : index
//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
@@ -120,12 +146,18 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
//CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
- //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
+ //CHECK: %[[c16_4:.*]] = arith.constant 16 : index
+ //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16_4]] : index
//CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
//CHECK: %[[c1:.*]] = arith.constant 1 : index
//CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
- //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
+ //CHECK: %[[cast:.*]] = arith.index_castui %[[add3]] : index to i32
+ //CHECK: %[[c2_i32:.*]] = arith.constant 2 : i32
+ //CHECK: %[[byte_offset:.*]] = arith.muli %[[cast]], %[[c2_i32]] : i32
+ //CHECK: %[[final_ptr:.*]] = arith.addi {{.*}}, %[[byte_offset]] : i32
+ //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[final_ptr]] : i32 to !llvm.ptr<3>
+ //CHECK: %[[loaded:.*]] = llvm.load %[[ptr]] : !llvm.ptr<3> -> f16
%1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
//CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
@@ -143,10 +175,14 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
//CHECK: %[[c16:.*]] = arith.constant 16 : index
- //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
- //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
- //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
- //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[c16_0:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16_0]] : index
+ //CHECK: %[[c16_1:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16_1]] : index
+ //CHECK: %[[c16_2:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16_2]] : index
+ //CHECK: %[[c16_3:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16_3]] : index
//CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
@@ -157,10 +193,15 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[c1:.*]] = arith.constant 1 : index
//CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
//CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
- //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+ //CHECK: %[[c16_5:.*]] = arith.constant 16 : index
+ //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16_5]] : index
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-
- //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
+ //CHECK: %[[cast:.*]] = arith.index_castui %[[add3]] : index to i32
+ //CHECK: %[[c2_i32:.*]] = arith.constant 2 : i32
+ //CHECK: %[[byte_offset:.*]] = arith.muli %[[cast]], %[[c2_i32]] : i32
+ //CHECK: %[[final_ptr:.*]] = arith.addi {{.*}}, %[[byte_offset]] : i32
+ //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[final_ptr]] : i32 to !llvm.ptr<3>
+ //CHECK: %[[loaded:.*]] = llvm.load %[[ptr]] : !llvm.ptr<3> -> vector<8xf16>
%tid_x = gpu.thread_id x
%c16 = arith.constant 16 : index
@@ -179,42 +220,52 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
- //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
- %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[cast0:.*]] = arith.index_castui %[[intptr]] : index to i32
+ //CHECK: %[[cast1:.*]] = arith.index_castui %[[c0]] : index to i32
+ //CHECK: %[[c1_i32:.*]] = arith.constant 1 : i32
+ //CHECK: %[[mul:.*]] = arith.muli %[[cast1]], %[[c1_i32]] : i32
+ //CHECK: %[[add:.*]] = arith.addi %[[cast0]], %[[mul]] : i32
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[c48:.*]] = arith.constant 48 : index
- %c16 = arith.constant 16 : index
- %c48 = arith.constant 48 : index
-
- //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
- //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
- //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
- //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[c16_0:.*]] = arith.constant 16 : index
+ //CHECK: %[[divsi0:.*]] = arith.divsi %[[c16]], %[[c16_0]] : index
+ //CHECK: %[[c16_1:.*]] = arith.constant 16 : index
+ //CHECK: %[[remsi0:.*]] = arith.remsi %[[c16]], %[[c16_1]] : index
+ //CHECK: %[[c16_2:.*]] = arith.constant 16 : index
+ //CHECK: %[[divsi1:.*]] = arith.divsi %[[c48]], %[[c16_2]] : index
+ //CHECK: %[[c16_3:.*]] = arith.constant 16 : index
+ //CHECK: %[[remsi1:.*]] = arith.remsi %[[c48]], %[[c16_3]] : index
+ //CHECK: %[[c0_4:.*]] = arith.constant 0 : index
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
- //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
- //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK: %[[mul0:.*]] = arith.muli %[[divsi0]], %[[c1024]] : index
+ //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0_4]] : index
//CHECK: %[[c256:.*]] = arith.constant 256 : index
- //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index
+ //CHECK: %[[mul1:.*]] = arith.muli %[[divsi1]], %[[c256]] : index
//CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
- //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index
+ //CHECK: %[[c16_5:.*]] = arith.constant 16 : index
+ //CHECK: %[[mul2:.*]] = arith.muli %[[remsi0]], %[[c16_5]] : index
//CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
//CHECK: %[[c1:.*]] = arith.constant 1 : index
- //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
- //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
- //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
- //CHECK: %[[c2:.*]] = arith.constant 2 : i32
- //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
- //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32]], %[[byteOffset]] : i32
- //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
- //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
- //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
+ //CHECK: %[[mul3:.*]] = arith.muli %[[remsi1]], %[[c1]] : index
+ //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+ //CHECK: %[[cast:.*]] = arith.index_castui %[[add3]] : index to i32
+ //CHECK: %[[c2_i32:.*]] = arith.constant 2 : i32
+ //CHECK: %[[byte_offset:.*]] = arith.muli %[[cast]], %[[c2_i32]] : i32
+ //CHECK: %[[final_ptr:.*]] = arith.addi %[[add]], %[[byte_offset]] : i32
+ //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[final_ptr]] : i32 to !llvm.ptr<3>
+ //CHECK: %[[blockload:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
+ //CHECK: %[[loaded:.*]] = vector.bitcast %[[blockload]] : vector<8xi16> to vector<8xf16>
+
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+
+ %c16 = arith.constant 16 : index
+ %c48 = arith.constant 48 : index
%1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
//CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16>
- //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
+ //CHECK: xevm.blockstore %[[ptr2:.*]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
@@ -232,4 +283,29 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
xegpu.store_matrix %loaded_2, %matrixdesc[16,0] : vector<1x16xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
gpu.return
}
+
+ //CHECK-LABEL: create_memdesc_from_subview
+ gpu.func @create_memdesc_from_subview(%arg0: memref<256x16xbf16, 3>) -> vector<1x16xbf16> {
+
+ %c0 = arith.constant 0 : index
+
+ %smem_coop_a = memref.subview %arg0[64, 0][1, 16][1, 1] : memref<256x16xbf16, 3> to memref<1x16xbf16, strided<[16, 1], offset: 1024>, 3>
+
+ //CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %{{.*}} : memref<1x16xbf16, strided<[16, 1], offset: 1024>, 3> -> index
+ //CHECK: %[[C1024:.*]] = arith.constant 1024 : index
+ //CHECK: %[[CAST0:.*]] = arith.index_castui %[[INTPTR]] : index to i32
+ //CHECK: %[[CAST1:.*]] = arith.index_castui %[[C1024]] : index to i32
+ //CHECK: %[[C2:.*]] = arith.constant 2 : i32
+ //CHECK: %[[MUL:.*]] = arith.muli %[[CAST1]], %[[C2]] : i32
+ //CHECK: %{{.*}} = arith.addi %[[CAST0]], %[[MUL]] : i32
+
+ %mdesc_coop_a = xegpu.create_mem_desc %smem_coop_a : memref<1x16xbf16, strided<[16, 1], offset: 1024>, 3> -> !xegpu.mem_desc<1x16xbf16>
+
+ %ret = xegpu.load_matrix%mdesc_coop_a[%c0, %c0]: !xegpu.mem_desc<1x16xbf16>, index, index -> vector<1x16xbf16>
+
+ gpu.return %ret : vector<1x16xbf16>
+
+ }
+
+
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
index d606cf51435dc..4c84699a069f0 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -107,3 +107,28 @@ gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vecto
gpu.return
}
}
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: @load_gather_from_dyn_memref_subview
+gpu.func @load_gather_from_dyn_memref_subview(%dyn: memref<?xf16>, %offset: vector<1xindex>, %mask: vector<1xi1>, %dst: memref<1xf16>) {
+ %c0 = arith.constant 0 : index
+ %id = gpu.subgroup_id : index
+ %src = memref.subview %dyn[%id][16][1] : memref<?xf16> to memref<16xf16, strided<[1], offset: ?>>
+
+ // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %{{.*}} : memref<16xf16, strided<[1], offset: ?>> -> memref<f16>, index, index, index
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE]] : memref<f16> -> index
+ // CHECK: %[[CAST1:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ // CHECK: %[[CAST2:.*]] = arith.index_castui %[[OFFSET]] : index to i64
+ // CHECK: %[[MUL1:.*]] = arith.muli %[[CAST2]], %{{.*}} : i64
+ // CHECK: %[[ADD1:.*]] = arith.addi %[[CAST1]], %[[MUL1]] : i64
+ // CHECK: %[[MUL2:.*]] = arith.muli %{{.*}}, %{{.*}} : i64
+ // CHECK: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL2]] : i64
+ // CHECK: %{{.*}} = llvm.inttoptr %[[ADD2]] : i64 to !llvm.ptr<1>
+
+ %0 = xegpu.load %src[%offset], %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : memref<16xf16, strided<[1], offset: ?>>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
+ vector.store %0, %dst[%c0] : memref<1xf16>, vector<1xf16>
+ gpu.return
+}
+}
More information about the Mlir-commits
mailing list