[Mlir-commits] [mlir] 447af32 - [MLIR][XeGPU][XeVM] create_nd_tdesc: use correct pitch from strides. (#170384)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 8 08:15:48 PST 2025
Author: Sang Ik Lee
Date: 2025-12-08T08:15:44-08:00
New Revision: 447af32fbb8c9522bc01948caadd1ae562c01373
URL: https://github.com/llvm/llvm-project/commit/447af32fbb8c9522bc01948caadd1ae562c01373
DIFF: https://github.com/llvm/llvm-project/commit/447af32fbb8c9522bc01948caadd1ae562c01373.diff
LOG: [MLIR][XeGPU][XeVM] create_nd_tdesc: use correct pitch from strides. (#170384)
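The base memory pitch should be derived from the base row stride (`strides[0]`), not from the base width; the two differ whenever rows are padded (stride > width).
Remove the offset fields from the tensor descriptor payload and add a pitch field in their place.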
Added:
Modified:
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 7f1ec17ce0ae8..9c99a24bea8cd 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16};
// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
enum class NdTdescOffset : uint32_t {
- BasePtr = 0, // Base pointer (i64)
- BaseShapeW = 2, // Base shape width (i32)
- BaseShapeH = 3, // Base shape height (i32)
- TensorOffsetW = 4, // Tensor offset W (i32)
- TensorOffsetH = 5 // Tensor offset H (i32)
+ BasePtr = 0, // Base pointer (i64)
+ BaseShapeW = 2, // Base shape width (i32)
+ BaseShapeH = 3, // Base shape height (i32)
+ BasePitch = 4, // Base pitch (i32)
};
static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
@@ -179,11 +178,10 @@ class CreateNdDescToXeVMPattern
Value baseAddr;
Value baseShapeW;
Value baseShapeH;
- Value offsetW;
- Value offsetH;
// Source can be a memref or a pointer (ui64, ui32, i64 or i32).
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
// Descriptor shape is expected to be 2D.
int64_t rank = mixedSizes.size();
auto sourceTy = source.getType();
@@ -216,12 +214,11 @@ class CreateNdDescToXeVMPattern
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
return val;
};
- // Offsets are not supported (0 is used).
- offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
- offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
// Get shape values from op fold results.
baseShapeW = createOffset(mixedSizes, 1);
baseShapeH = createOffset(mixedSizes, 0);
+ // Get pitch value from op fold results.
+ Value basePitch = createOffset(mixedStrides, 0);
// Populate payload.
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -235,12 +232,9 @@ class CreateNdDescToXeVMPattern
payload =
vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
static_cast<int>(NdTdescOffset::BaseShapeH));
- payload = vector::InsertOp::create(
- rewriter, loc, offsetW, payload,
- static_cast<int>(NdTdescOffset::TensorOffsetW));
- payload = vector::InsertOp::create(
- rewriter, loc, offsetH, payload,
- static_cast<int>(NdTdescOffset::TensorOffsetH));
+ payload =
+ vector::InsertOp::create(rewriter, loc, basePitch, payload,
+ static_cast<int>(NdTdescOffset::BasePitch));
rewriter.replaceOp(op, payload);
return success();
}
@@ -289,6 +283,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
Value baseShapeH = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+ Value basePitch = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
// Offsets are provided by the op.
// convert them to i32.
Value offsetW =
@@ -303,8 +299,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
// Compute width in bytes.
- Value surfaceW =
+ Value baseWidthByte =
arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+ // Compute pitch in bytes.
+ Value basePitchByte =
+ arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
// Get tile width from the tensor descriptor type.
auto tileW = tdescTy.getDimSize(tileRank - 1);
@@ -331,8 +330,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
xevm::BlockStore2dOp::create(
- rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
- offsetH, elemBitSize, tileW, tileH, src,
+ rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
+ basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH, src,
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
@@ -340,9 +339,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
xevm::BlockPrefetch2dOp::create(
- rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
- offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
- xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
+ basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+ vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
rewriter.eraseOp(op);
} else {
VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
@@ -355,9 +354,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
: rewriter.getIntegerType(elemBitSize));
Value resultFlatVec = xevm::BlockLoad2dOp::create(
- rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
- surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
- transpose, vnni,
+ rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte, baseShapeH,
+ basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+ vblocks, transpose, vnni,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
resultFlatVec = vector::BitCastOp::create(
rewriter, loc,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 8b87b791c9fd3..9a1e2cb3c7de0 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -8,21 +8,19 @@ gpu.module @create_nd_tdesc {
gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
%stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
// CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
- // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+ // CHECK: %[[DYN_ADDR:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
- // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
// CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
// CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
+ // CHECK: %[[PITCH:.*]] = arith.index_cast %[[ARG4]] : index to i32
// CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
- // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
- // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[VAR11:.*]] = vector.insert %[[PITCH]], %[[VAR10]] [4] : i32 into vector<8xi32>
%ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
: ui64 -> !xegpu.tensor_desc<8x16xf32>
@@ -32,19 +30,18 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
- // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
// CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
// CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
// CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
+ // CHECK: %[[C32_I64_2:.*]] = arith.constant 32 : i64
+ // CHECK: %[[PITCH2:.*]] = arith.trunci %[[C32_I64_2]] : i64 to i32
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
- // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
- // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[VAR19:.*]] = vector.insert %[[PITCH2]], %[[VAR18]] [4] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -53,18 +50,16 @@ gpu.module @create_nd_tdesc {
%size_x = arith.constant 64 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
%BLOCK_DMODEL = arith.constant 16 : index
- // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
- // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
- // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
- // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
- // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
- // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
- // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
- // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
- // CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32>
- // CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[CST_3:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[SHAPE_W3:.*]] = arith.index_cast %[[C16]] : index to i32
+ // CHECK: %[[SHAPE_H3:.*]] = arith.index_cast %[[C64]] : index to i32
+ // CHECK: %[[PITCH3:.*]] = arith.index_cast %[[C16]] : index to i32
+ // CHECK: %[[VAR25:.*]] = vector.bitcast %[[CST_3]] : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR]], %[[VAR25]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR27:.*]] = vector.bitcast %[[VAR26]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[VAR28:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR27]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR28]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[VAR30:.*]] = vector.insert %[[PITCH3]], %[[VAR29]] [4] : i32 into vector<8xi32>
%dyn_tdesc = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
gpu.return
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index afeae8be24b72..4c73c9c238b6e 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -1,78 +1,32 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
gpu.module @load_store_check {
// CHECK-LABEL: gpu.func @load_store(
- // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: memref<8x16xf32, 1>) kernel {
gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+ // CHECK: %[[W_P_BYTES:.*]] = arith.constant 64 : i32
+ // CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
+ // CHECK: %[[H:.*]] = arith.constant 8 : i32
%srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
%dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
- // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
- // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
- // CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]] : memref<8x16xf32, 1> to memref<8x16xf32>
- // CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<8x16xf32> -> index
- // CHECK: %[[ST_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR_1]] : index to i64
- // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
- // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
- // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-
- //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
- //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
- //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
- //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
- //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
- //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
- //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
- //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
- //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
- //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
- //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
- //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
- //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
+ //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]]
//CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
//CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
//CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
%loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
- //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
%tid_x = gpu.thread_id x
%tid_x_i32 = arith.index_cast %tid_x : index to i32
%tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
- //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
%loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
- // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
- // CHECK: %[[DESC_0:.*]] = vector.insert %[[ST_PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
- // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
- // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
- // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
- // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
%dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
- //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
- //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
- //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
- //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
- //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
- //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
- //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
- //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
- //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
- //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
- //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
- //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
- //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
- //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
- //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+ //CHECK: xevm.blockstore2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]], %{{.*}} <{
+ //CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
//CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
: vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
index e4b303087ea9b..43df721fb77a0 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
@@ -3,27 +3,16 @@
gpu.module @prefetch_nd_check {
// CHECK-LABEL: gpu.func @prefetch_nd
gpu.func @prefetch_nd(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
- // CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.constant 64 : i32
- // CHECK: %[[LD_CREATE_DESC_I64:.*]] = arith.constant dense<0> : vector<4xi64>
- // CHECK: %[[PREF_BASE_H:.*]] = arith.constant 8 : i32
- // CHECK: %[[PREF_BASE_W:.*]] = arith.constant 16 : i32
+ // CHECK: %[[BASE_WIDTH_PITCH_BYTES:.*]] = arith.constant 64 : i32
// CHECK: %[[OFFSET_ZERO:.*]] = arith.constant 0 : i32
+ // CHECK: %[[BASE_H:.*]] = arith.constant 8 : i32
%srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
- // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
- // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
- // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[LD_DESC_2:.*]] = vector.insert %[[PREF_BASE_W]], %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC_3:.*]] = vector.insert %[[PREF_BASE_H]], %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC_4:.*]] = vector.insert %[[OFFSET_ZERO]], %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC:.*]] = vector.insert %[[OFFSET_ZERO]], %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32,
#xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
- //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
- //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
- //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]],
- //CHECK-SAME: %[[PREF_BASE_ROW_IN_BYTES]], %[[OFFSET_ZERO]], %[[OFFSET_ZERO]]
+ //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<1>
+ //CHECK: xevm.blockprefetch2d %[[LLVMPTR]], %[[BASE_WIDTH_PITCH_BYTES]], %[[BASE_H]],
+ //CHECK-SAME: %[[BASE_WIDTH_PITCH_BYTES]], %[[OFFSET_ZERO]], %[[OFFSET_ZERO]]
//CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
//CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}>
//CHECK-SAME: : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
More information about the Mlir-commits mailing list